In [None]:
!pip install -U networkit

## Modules

In [None]:
def get_path(*args, **kwargs):
    """Return the hard-coded dataset directory.

    Any positional or keyword arguments are accepted and ignored.
    """
    dataset_dir = '/content/recohut/datasets/ml-100k'
    return dataset_dir

### Utils

#### Torch utils

In [None]:
import torch

is_cuda_available = torch.cuda.is_available()

# Bind the tensor constructors once, so the rest of the file can build
# tensors without checking the device every time.
if not is_cuda_available:
    LongTensor, FloatTensor, BoolTensor = (
        torch.LongTensor,
        torch.FloatTensor,
        torch.BoolTensor,
    )
else:
    print("Using CUDA...\n")
    LongTensor, FloatTensor, BoolTensor = (
        torch.cuda.LongTensor,
        torch.cuda.FloatTensor,
        torch.cuda.BoolTensor,
    )

# def get_model_class(hyper_params):
#     from pytorch_models import MF, MVAE, SASRec, SVAE

#     return {
#         "bias_only": MF.MF,
#         "MF_dot": MF.MF,
#         "MF": MF.MF,
#         "MVAE": MVAE.MVAE,
#         "SVAE": SVAE.SVAE,
#         "SASRec": SASRec.SASRec,
#     }[hyper_params['model_type']]

def get_model_class(hyper_params):
    """Return the model class registered for `hyper_params['model_type']`.

    The three supported types ("bias_only", "MF_dot", "MF") all map to the
    `MF` class; any other type raises `KeyError`.
    """
    registry = dict.fromkeys(("bias_only", "MF_dot", "MF"), MF)
    return registry[hyper_params['model_type']]


def xavier_init(model):
    """Apply Xavier-uniform initialisation to every parameter that supports it.

    This is deliberately best-effort: parameters for which
    `torch.nn.init.xavier_uniform_` fails (for example tensors with fewer
    than two dimensions, where fan-in/fan-out is undefined) are left as-is.

    Args:
        model: object exposing `named_parameters()` (e.g. a `torch.nn.Module`).
    """
    for _, param in model.named_parameters():
        try:
            torch.nn.init.xavier_uniform_(param.data)
        except Exception:
            # Skip layers that cannot be Xavier-initialised. Catching
            # `Exception` rather than using a bare `except` lets
            # KeyboardInterrupt / SystemExit propagate.
            continue

#### Utils

In [None]:
# Large finite stand-in for infinity; the evaluation code negates it to push
# already-consumed items to the bottom of a ranking.
INF = float(1e6)

# def get_data_loader_class(hyper_params):
#     from data_loaders import MF, MVAE, SASRec, SVAE

#     return {
#         "pop_rec": (MF.TrainDataset, MF.TestDataset),
#         "bias_only": (MF.TrainDataset, MF.TestDataset),
#         "MF_dot": (MF.TrainDataset, MF.TestDataset),
#         "MF": (MF.TrainDataset, MF.TestDataset),
#         "NeuMF": (MF.TrainDataset, MF.TestDataset),
#         "MVAE": (MVAE.TrainDataset, MVAE.TestDataset),
#         "SVAE": (SVAE.TrainDataset, SVAE.TestDataset),
#         "SASRec": (SASRec.TrainDataset, SASRec.TestDataset),
#     }[hyper_params['model_type']]

def get_data_loader_class(hyper_params):
    """Return the `(train loader class, test loader class)` pair for a model type.

    Every supported model type currently shares the same
    `(TrainDataset, TestDataset)` pair; an unknown type raises `KeyError`.
    """
    supported_models = (
        "pop_rec", "bias_only", "MF_dot", "MF",
        "NeuMF", "MVAE", "SVAE", "SASRec",
    )
    loaders = { name: (TrainDataset, TestDataset) for name in supported_models }
    return loaders[hyper_params['model_type']]

def valid_hyper_params(hyper_params):
    """Tell whether `hyper_params['task']` is supported by `hyper_params['model_type']`.

    Raises `KeyError` for an unknown model type or a missing key.
    """
    every_task = [ 'explicit', 'implicit', 'sequential' ]
    feedback_only = [ 'implicit', 'sequential' ]
    sequential_only = [ 'sequential' ]

    tasks_by_model = {
        "pop_rec":   feedback_only,
        "bias_only": every_task,
        "MF_dot":    every_task,
        "MF":        every_task,
        "NeuMF":     every_task,
        "MVAE":      feedback_only,
        "SVAE":      sequential_only,
        "SASRec":    sequential_only,
    }

    allowed_tasks = tasks_by_model[hyper_params['model_type']]
    return hyper_params['task'] in allowed_tasks

def get_common_path(hyper_params, star_match = False):
    """Build the run identifier shared by log and model file names.

    Returns `None` when the model type cannot run the requested task (e.g.
    SASRec on an explicit/implicit feedback task). With `star_match` set,
    any missing hyper-parameter is rendered as the regex wildcard ".*"
    instead of raising `KeyError`.
    """
    if not valid_hyper_params(hyper_params):
        return None

    # Shorthand for reading a hyper-parameter, honouring `star_match`.
    def get(key):
        return hyper_params.get(key, ".*") if star_match else hyper_params[key]

    def latent_suffix():
        return "_latent_size_{}_dropout_{}".format(get('latent_size'), get('dropout'))

    def no_suffix():
        return ""

    pieces = [ "{}_{}".format(get('dataset'), get('task')) ]

    # Sampling scheme
    sampling = get('sampling')
    if sampling[:3] == 'svp':
        pieces.append("_{}_{}_perc_{}".format(sampling, get('sampling_percent'), get('sampling_svp')))
    elif sampling == 'complete_data':
        pieces.append("_complete_data")
    else:
        pieces.append("_{}_perc_{}".format(get('sampling_percent'), sampling))

    # Model-specific part; the builders are callables so that only the
    # hyper-parameters of the selected model are ever read.
    suffix_builders = {
        ".*":        no_suffix,
        "pop_rec":   no_suffix,
        "bias_only": no_suffix,
        "MF_dot":    latent_suffix,
        "MF":        latent_suffix,
        "NeuMF":     latent_suffix,
        "MVAE":      latent_suffix,
        "SVAE":      lambda: "_latent_size_{}_dropout_{}_next_{}".format(get('latent_size'), get('dropout'), get('num_next')),
        "SASRec":    lambda: "_latent_size_{}_dropout_{}_heads_{}_blocks_{}".format(get('latent_size'), get('dropout'), get('num_heads'), get('num_blocks')),
    }
    model_type = get('model_type')
    pieces.append("_{}".format(model_type) + suffix_builders[model_type]())

    if get('task') in [ 'implicit', 'sequential' ]:
        pieces.append("_trn_negs_{}_tst_negs_{}".format(get('num_train_negs'), get('num_test_negs')))

    pieces.append("_wd_{}_lr_{}".format(get('weight_decay'), get('lr')))

    return "".join(pieces)

def remap_items(data):
    """Renumber item ids to 1, 2, 3, ... in order of first appearance.

    `data` is a list of per-user lists of `(item, rating, time)` triples. It
    is rewritten in place (each triple becomes a 3-element list) and returned.
    """
    new_id_of = {}
    for interactions in data:
        for item, _rating, _time in interactions:
            # The next free id is evaluated before insertion, so ids start at 1.
            new_id_of.setdefault(item, len(new_id_of) + 1)

    for u, interactions in enumerate(data):
        data[u] = [ [ new_id_of[row[0]], row[1], row[2] ] for row in interactions ]

    return data

def file_write(log_file, s, dont_print=False):
    """Append `s` plus a newline to `log_file`, echoing it to stdout first.

    Args:
        log_file: path of the file to append to (created if missing).
        s: text to log.
        dont_print: when true, skip the echo to stdout.
    """
    if not dont_print:
        print(s)
    # The context manager closes the handle even if the write fails.
    with open(log_file, 'a') as f:
        f.write(s + '\n')

def clear_log_file(log_file):
    """Truncate `log_file` to zero length, creating it if it does not exist."""
    # Opening in 'w' mode truncates; the context manager guarantees the
    # handle is closed even on error.
    with open(log_file, 'w') as f:
        f.write('')

def pretty_print(h):
    """Print mapping `h` as a brace-delimited block, one `key: value` per line.

    Both keys and values are converted with `str()`, so numeric or boolean
    values (typical for hyper-parameters) print instead of raising
    `TypeError` on string concatenation.
    """
    print("{")
    for key in h:
        print(' ' * 4 + str(key) + ': ' + str(h[key]))
    print('}\n')

def log_end_epoch(hyper_params, metrics, epoch, time_elpased, metrics_on = '(VAL)', dont_print = False):
    """Write an end-of-epoch summary to `hyper_params['log_file']`.

    The summary is framed by two 89-character rules and lists the epoch
    number, the elapsed time and every entry of `metrics`, followed by the
    `metrics_on` tag. `dont_print` is forwarded to `file_write`.
    """
    rule = '-' * 89

    metric_text = "".join(
        " | " + name + ' = ' + str(metrics[name]) for name in metrics
    )
    metric_text = metric_text + ' ' + metrics_on

    header = '\n| end of epoch {} | time = {:5.2f}'.format(epoch, time_elpased)

    message = rule + header + metric_text + '\n' + rule
    file_write(hyper_params['log_file'], message, dont_print = dont_print)

#### Paths

In [None]:
import os

# Sampling experiments' constants
# Root directory for sampling-run artefacts. The path helpers in this file
# append "/results/..." to it, so the resulting paths contain a double slash
# (harmless on POSIX file systems).
BASE_SAMPLING_PATH = "./experiments/sampling_runs/"

# Data-genie experiments' constants
# Root directory for data-genie artefacts.
BASE_DATA_GENIE_PATH = "./experiments/data_genie/"

def get_svp_log_file_path(hyper_params):
    """Return the log-file path of the SVP run described by `hyper_params`."""
    relative_path = "/results/logs/SVP/{}.txt".format(get_common_path(hyper_params))
    return BASE_SAMPLING_PATH + relative_path

def get_svp_model_file_path(hyper_params):
    """Return the model-checkpoint path of the SVP run described by `hyper_params`."""
    relative_path = "/results/models/SVP/{}.pt".format(get_common_path(hyper_params))
    return BASE_SAMPLING_PATH + relative_path

def get_log_base_path():
    """Return the directory that holds the logs of trained models."""
    logs_subdir = "/results/logs/trained/"
    return BASE_SAMPLING_PATH + logs_subdir

def get_log_file_path(hyper_params):
    """Return the log-file path of the training run described by `hyper_params`."""
    log_dir = get_log_base_path()
    run_name = get_common_path(hyper_params)
    return log_dir + run_name + ".txt"

def get_model_file_path(hyper_params):
    """Return the checkpoint path of the training run described by `hyper_params`."""
    models_dir = BASE_SAMPLING_PATH + "/results/models/trained/"
    run_name = get_common_path(hyper_params)
    return models_dir + run_name + ".pt"

def get_data_path(hyper_params):
    """Return the directory that holds a dataset's files.

    Args:
        hyper_params: either the dataset name itself (a string) or a mapping
            with a 'dataset' entry.

    Returns:
        "./datasets/<dataset>/".
    """
    # `isinstance` (rather than an exact type comparison) also accepts
    # `str` subclasses as dataset names.
    if isinstance(hyper_params, str):
        dataset = hyper_params
    else:
        dataset = hyper_params['dataset']
    return "./datasets/{}/".format(dataset)

def get_index_path(hyper_params):
    """Return the directory holding the train/val/test index of a sampled dataset.

    The path is composed of the dataset directory, a split-strategy folder
    chosen by `hyper_params['task']`, and a folder describing the sampling
    scheme. An unknown task raises `KeyError`.
    """
    split_by_task = {
        'explicit':    '20_percent_hist',
        'implicit':    '20_percent_hist',
        'sequential':  'leave_2',
    }
    split_name = split_by_task[hyper_params['task']]

    path_parts = [
        get_data_path(hyper_params['dataset']),
        "/{}/".format(split_name),
    ]

    sampling = hyper_params['sampling']
    if sampling[:3] == 'svp':
        path_parts.append("{}_{}/{}_perc_{}/".format(
            sampling, hyper_params['task'],
            hyper_params['sampling_percent'], hyper_params['sampling_svp']
        ))
    elif sampling == 'complete_data':
        path_parts.append("complete_data/")
    else:
        path_parts.append("{}_perc_{}/".format(
            hyper_params['sampling_percent'], sampling,
        ))

    return "".join(path_parts)

### Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# from torch_utils import LongTensor, FloatTensor

class BaseMF(nn.Module):
    """Shared scaffolding for the bias / matrix-factorisation scorers.

    Holds a global bias, per-user and per-item bias vectors and, when
    `keep_gamma` is true, user and item embedding tables. Sub-classes supply
    `get_score`; `forward` only decides which (user, item) pairs get scored.
    """
    def __init__(self, hyper_params, keep_gamma = True):
        # Reads 'task', 'total_users' and 'total_items' from `hyper_params`,
        # plus 'latent_size' when `keep_gamma` is set.
        super(BaseMF, self).__init__()
        self.hyper_params = hyper_params

        # Declaring alpha, beta, gamma
        # The global bias starts at 4.0 for the explicit task and 0.5 otherwise;
        # user and item biases start at zero.
        self.global_bias = nn.Parameter(FloatTensor([ 4.0 if hyper_params['task'] == 'explicit' else 0.5 ]))
        self.user_bias = nn.Parameter(FloatTensor([ 0.0 for _ in range(hyper_params['total_users']) ]))
        self.item_bias = nn.Parameter(FloatTensor([ 0.0 for _ in range(hyper_params['total_items']) ]))
        if keep_gamma:
            self.user_embedding = nn.Embedding(hyper_params['total_users'], hyper_params['latent_size'])
            self.item_embedding = nn.Embedding(hyper_params['total_items'], hyper_params['latent_size'])

        # For faster evaluation
        # Pre-built vector of every item id, reused whenever all items are ranked.
        self.all_items_vector = LongTensor(
            list(range(hyper_params['total_items']))
        )

    def get_score(self, data):
        # NOTE(review): declared with a single `data` argument, but `forward`
        # calls it as get_score(user_ids, item_ids) -- sub-classes are expected
        # to override it with that two-argument form.
        pass # Virtual function, implement in all sub-classes

    def forward(self, data, eval = False):
        """Score the (user, item) pairs described by `data`.

        `data` unpacks into `(user_id, pos_item_id, neg_items)` and selects
        one of three modes:

        * `pos_item_id is None`: every user in the batch is scored against all
          items; the per-user rows are concatenated into one tensor.
        * `neg_items is None` (explicit feedback): one score per
          (user, positive item) pair.
        * otherwise (implicit feedback): a pair of tensors shaped like
          `pos_item_id` and `neg_items` respectively.

        The `eval` flag is accepted but not used in this method.
        """
        user_id, pos_item_id, neg_items = data

        # Evaluation -- Rank all items
        if pos_item_id is None: 
            ret = []
            for b in range(user_id.shape[0]):
                # Repeat this user's id once per item so that users and items
                # line up element-wise, then keep the scores as a single row.
                ret.append(self.get_score(
                    user_id[b].unsqueeze(-1).repeat(1, self.hyper_params['total_items']).view(-1), 
                    self.all_items_vector.view(-1)
                ).view(1, -1))
            return torch.cat(ret)
        
        # Explicit feedback
        if neg_items is None: return self.get_score(user_id, pos_item_id.squeeze(-1))
        
        # Implicit feedback
        # Each user id is repeated once per column of the item matrix, both are
        # flattened for scoring, and the scores are reshaped back to the item
        # matrix's shape.
        return self.get_score(
            user_id.unsqueeze(-1).repeat(1, pos_item_id.shape[1]).view(-1), 
            pos_item_id.view(-1)
        ).view(pos_item_id.shape), self.get_score(
            user_id.unsqueeze(-1).repeat(1, neg_items.shape[1]).view(-1), 
            neg_items.view(-1)
        ).view(neg_items.shape)

class MF(BaseMF):
    """Bias-only, dot-product, or MLP-augmented matrix-factorisation scorer.

    `hyper_params['model_type']` selects the variant:

    * 'bias_only': global + user + item bias, no embeddings.
    * 'MF_dot': biases plus the dot product of user and item embeddings.
    * 'MF': biases plus a learned score computed from the element-wise
      product of the embeddings concatenated with an MLP projection of them.
    """
    def __init__(self, hyper_params):
        # Embeddings (and their dropout) are only needed when latent factors are used.
        keep_gamma = hyper_params['model_type'] != 'bias_only'

        super(MF, self).__init__(hyper_params, keep_gamma = keep_gamma)
        if keep_gamma: self.dropout = nn.Dropout(hyper_params['dropout'])

        if hyper_params['model_type'] == 'MF':
            latent_size = hyper_params['latent_size']

            # MLP over the concatenated [user, item] embeddings.
            self.projection = nn.Sequential(
                nn.Dropout(hyper_params['dropout']),
                nn.Linear(2 * latent_size, latent_size),
                nn.ReLU(),
                nn.Linear(latent_size, latent_size)
            )
            for m in self.projection:
                if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight)

            # Maps the concatenated [MLP output, element-wise product] to one score.
            self.final = nn.Linear(2 * latent_size, 1)
            # NOTE(review): `sigmoid` and `relu` are created here but not used
            # by `get_score` below.
            self.sigmoid = nn.Sigmoid()
            self.relu = nn.ReLU()

    def get_score(self, user_id, item_id):
        """Return one score per (user, item) pair, shaped like `user_id`.

        `user_id` and `item_id` are index tensors that are flattened before
        lookup; they are expected to line up element-wise.
        """
        # Bias terms, looked up per pair and reshaped back to the input shapes
        user_bias = self.user_bias.gather(0, user_id.view(-1)).view(user_id.shape)
        item_bias = self.item_bias.gather(0, item_id.view(-1)).view(item_id.shape)

        if self.hyper_params['model_type'] == 'bias_only': 
            return user_bias + item_bias + self.global_bias

        # Embed Latent space
        user = self.dropout(self.user_embedding(user_id.view(-1))) # [num_pairs x latent_size]
        item = self.dropout(self.item_embedding(item_id.view(-1))) # [num_pairs x latent_size]

        # Dot product
        if self.hyper_params['model_type'] == 'MF_dot':
            rating = torch.sum(user * item, dim = -1).view(user_id.shape)
            return user_bias + item_bias + self.global_bias + rating

        # 'MF' variant: combine the element-wise product with the MLP projection
        mf_vector = user * item
        cat = torch.cat([ user, item ], dim = -1)
        mlp_vector = self.projection(cat)

        # Concatenate and get single score
        cat = torch.cat([ mlp_vector, mf_vector ], dim = -1)
        rating = self.final(cat)[:, 0].view(user_id.shape) # one score per pair

        return user_bias + item_bias + self.global_bias + rating

### Graph Sampler

#### ForestFire

In [None]:
import random
import numpy as np
from collections import deque

class ForestFireSampler:
    """An implementation of forest fire sampling. The procedure is a stochastic
    snowball sampling method where the expansion is proportional to the burning probability. 
    `"For details about the algorithm see this paper." <https://cs.stanford.edu/people/jure/pubs/sampling-kdd06.pdf>`_
    Inspiration credit: 
        littleballoffur
        https://github.com/benedekrozemberczki/littleballoffur
    Args:
        number_of_nodes (int): Number of sampled nodes. Default is 100.
        p (float): Burning probability. Default is 0.4.
        seed (int): Random seed. Default is 42.
        max_visited_nodes_backlog (int): Maximum number of seen-but-unburned
            neighbours remembered for re-igniting the fire. Default is 100.
        restart_hop_size (int): Maximum number of remembered nodes moved back
            into the queue each time the fire dies out. Default is 10.
    """
    def __init__(self, number_of_nodes: int=100, p: float=0.4, seed: int=42, max_visited_nodes_backlog: int=100,
                 restart_hop_size: int = 10):
        self.number_of_nodes = number_of_nodes
        self.p = p
        self.seed = seed
        self._set_seed()
        self.restart_hop_size = restart_hop_size
        self.max_visited_nodes_backlog = max_visited_nodes_backlog

    def _set_seed(self):
        """Seed the global `random` and `numpy.random` generators."""
        random.seed(self.seed)
        np.random.seed(self.seed)

    def _create_node_sets(self, graph):
        """
        Create a starting set of nodes.

        Node ids are taken to be 0 .. graph.number_of_nodes() - 1.
        """
        self._sampled_nodes = set()
        self._set_of_nodes = set(range(graph.number_of_nodes()))
        self._visited_nodes = deque(maxlen=self.max_visited_nodes_backlog)

    def get_neighbors(self, graph, node):
        """Return the neighbours of `node` as a list."""
        return list(graph.neighbors(node))

    def _start_a_fire(self, graph):
        """
        Starting a forest fire from a single node.

        Burns outward from a random not-yet-sampled node until
        `number_of_nodes` nodes are sampled or no candidate is left.
        Raises `IndexError` if every node has already been sampled.
        """
        remaining_nodes = list(self._set_of_nodes.difference(self._sampled_nodes))
        seed_node = random.choice(remaining_nodes)
        self._sampled_nodes.add(seed_node)
        node_queue = deque([seed_node])
        while len(self._sampled_nodes) < self.number_of_nodes:
            if len(node_queue) == 0:
                # The fire died out: re-ignite it from remembered neighbours.
                node_queue = deque([self._visited_nodes.popleft()
                              for k in range(min(self.restart_hop_size, len(self._visited_nodes)))])
                if len(node_queue) == 0:
                    # Could not collect the required number of nodes.
                    break
            top_node = node_queue.popleft()
            self._sampled_nodes.add(top_node)
            neighbors = set(self.get_neighbors(graph, top_node))
            unvisited_neighbors = neighbors.difference(self._sampled_nodes)
            score = np.random.geometric(self.p)
            count = min(len(unvisited_neighbors), score)
            # `random.sample` rejects sets on Python >= 3.11. Converting to a
            # tuple is what older versions did internally, so the sampled
            # neighbours are unchanged there.
            burned_neighbors = random.sample(tuple(unvisited_neighbors), count)
            self._visited_nodes.extendleft(unvisited_neighbors.difference(set(burned_neighbors)))
            for neighbor in burned_neighbors:
                if len(self._sampled_nodes) >= self.number_of_nodes:
                    break
                node_queue.append(neighbor)

#### RandomWalk

In [None]:
import random

class RandomWalkWithRestartSampler:
    """Node sampling by random walks with restart.

    A discrete random walker moves between neighbouring nodes and, with a
    fixed probability, teleports back to the starting node, which yields a
    connected subsample of the input graph. `"For details about the algorithm
    see this paper." <https://cs.stanford.edu/people/jure/pubs/sampling-kdd06.pdf>`_
    Inspiration credit:
        littleballoffur
        https://github.com/benedekrozemberczki/littleballoffur
    Args:
        number_of_nodes (int): Number of nodes. Default is 100.
        seed (int): Random seed. Default is 42.
        p (float): Restart probability. Default is 0.1.
    """
    def __init__(self, number_of_nodes: int=100, seed: int=42, p: float=0.1):
        self.number_of_nodes = number_of_nodes
        self.seed = seed
        self.p = p
        self._set_seed()

    def _set_seed(self):
        """Seed the global `random` generator."""
        random.seed(self.seed)

    def get_neighbors(self, graph, node):
        """Return the neighbours of `node` as a list."""
        neighbours = graph.neighbors(node)
        return list(neighbours)

    def get_random_neighbor(self, graph, node):
        """Pick one neighbour of `node` uniformly at random."""
        candidates = self.get_neighbors(graph, node)
        return random.choice(candidates)

    def get_nodes(self, graph):
        """Return all nodes of `graph` as a list."""
        return list(graph.nodes)

    def get_number_of_nodes(self, graph):
        """Return the number of nodes in `graph`."""
        return graph.number_of_nodes()

    def _create_initial_node_set(self, graph, start_node):
        """
        Choose the node the walk starts from (and restarts at).

        A given `start_node` must lie in [0, number of nodes), otherwise
        `ValueError` is raised; `None` selects a random node.
        """
        self._set_of_nodes = set(self.get_nodes(graph))

        if start_node is None:
            start_node = random.choice(range(self.get_number_of_nodes(graph)))
        elif not (0 <= start_node < self.get_number_of_nodes(graph)):
            raise ValueError("Starting node index is out of range.")

        self._current_node = start_node
        self._sampled_nodes = {start_node}
        self._initial_node = start_node

    def _do_a_step(self, graph):
        """
        Take a single random walk step: restart with probability `p`,
        otherwise move to (and sample) a random neighbour.
        """
        draw = random.uniform(0, 1)
        if draw < self.p:
            self._current_node = self._initial_node
            return

        next_node = self.get_random_neighbor(graph, self._current_node)
        self._sampled_nodes.add(next_node)
        self._current_node = next_node

### MF Data Loader

In [None]:
import torch
import numpy as np
from collections import defaultdict
from torch.multiprocessing import Process, Queue, Event

class CombinedBase:
    """Helpers shared by the train and test data loaders.

    Sub-classes are expected to set `num_interactions`, `batch_size`,
    `num_users` and `num_items` before using the methods that read them.
    """
    def __init__(self): pass

    def __len__(self):
        """Number of batches per pass: ceil(num_interactions / batch_size)."""
        # Ceiling division, so an exact multiple of the batch size is not
        # reported as one batch too many.
        return -(-self.num_interactions // self.batch_size)

    def __del__(self):
        # Best-effort shutdown of the background sampler. Loaders that never
        # started one have no `self.p`, which is silently ignored.
        try:
            self.p.terminate()
            self.p.join()
        except Exception:
            pass

    def make_user_history(self, data):
        """Group the item ids of `(user, item, rating)` triples by user index."""
        user_history = [ [] for _ in range(self.num_users) ]
        for u, i, r in data: user_history[u].append(i)
        return user_history

    def pad(self, arr, max_len = None, pad_with = -1, side = 'right'):
        """Pad or truncate every list in `arr` to a common length (at most 200).

        Args:
            arr: list of lists; modified in place and returned.
            max_len: target length; defaults to the longest list in `arr`.
            pad_with: padding value. The default -1 means "repeat the list's
                last element" (or 0 for an empty list).
            side: 'right' appends the padding, anything else prepends it.

        Lists longer than the target keep their last elements.
        """
        seq_len = max_len if max_len is not None else max(map(len, arr))
        seq_len = min(seq_len, 200) # You don't need more than this

        for i in range(len(arr)):
            missing = seq_len - len(arr[i])
            if missing > 0:
                if pad_with != -1: pad_elem = pad_with
                else: pad_elem = arr[i][-1] if len(arr[i]) > 0 else 0
                # Build the padding once instead of growing the list one
                # element at a time.
                filler = [ pad_elem ] * missing
                if side == 'right': arr[i].extend(filler)
                else: arr[i] = filler + arr[i]
            arr[i] = arr[i][-seq_len:] # Keep last `seq_len` items
        return arr

    def sequential_pad(self, arr, hyper_params):
        """Left-pad `arr` to `hyper_params['max_seq_len']` with `hyper_params['total_items']`."""
        # Padding left side so that we can simply take out [:, -1, :] in the output
        return self.pad(
            arr, max_len = hyper_params['max_seq_len'], 
            pad_with = hyper_params['total_items'], side = 'left'
        )

    def scatter(self, batch, tensor_kind, last_dimension):
        """Return a `len(batch) x last_dimension` tensor with 1 at the column indices in `batch`."""
        ret = tensor_kind(len(batch), last_dimension).zero_()

        if not torch.is_tensor(batch):
            # Build the index tensor on the same device as the output.
            if ret.is_cuda: batch = torch.cuda.LongTensor(batch)
            else: batch = torch.LongTensor(batch)

        return ret.scatter_(1, batch, 1)

    # NOTE: is_negative(user, item) is a function which tells 
    # if the item is a negative item for the user
    def sample_negatives(self, num_points, num_negs, is_negative):
        """Draw `num_negs` negative item ids for each of `num_points` points.

        Returns a list of `num_points` lists. Duplicates are possible, and 0
        is used as filler if the pre-drawn random numbers run out.
        """
        # Sample all the random numbers you need at once as this is much faster than 
        # calling random.randint() once every time
        random_numbers = np.random.randint(
            self.num_items, 
            size = int(num_points * num_negs * 1.5)
        )

        negatives, at = [], 0
        for u in range(num_points):
            temp_negatives = []
            while len(temp_negatives) < num_negs:
                ## Negatives not possible
                if at >= len(random_numbers):
                    temp_negatives.append(0)
                    continue

                random_item = random_numbers[at] ; at += 1
                if is_negative(u, random_item):
                    # allowing duplicates, rare possibility
                    temp_negatives.append(random_item)
            negatives.append(temp_negatives)

        return negatives

    # So that training, GPU copying etc. 
    # doesn't have to wait for negative sampling
    def init_background_sampler(self, function):
        """Start a daemon process that keeps putting `function()` results on `self.result_queue`.

        After each result the process waits on `self.event`; the queue holds
        at most four results.
        """
        self.event = Event()
        self.result_queue = Queue(maxsize=4)
        
        def sample(result_queue):
            try:
                while True:
                    result_queue.put(function())
                    self.event.wait()
            except KeyboardInterrupt: pass
        self.p = Process(target = sample, args=(self.result_queue, ))
        self.p.daemon = True ; self.p.start()

class BaseTrainDataset(CombinedBase):
    """Training-side state shared by the data loaders: per-user histories and item propensities."""
    def __init__(self, data, hyper_params):
        self.hyper_params = hyper_params
        self.batch_size = hyper_params['batch_size']
        self.implicit_task = hyper_params['task'] in [ 'implicit', 'sequential' ]
        self.data = data
        self.num_users = hyper_params['total_users']
        self.num_items = hyper_params['total_items']

        # Per-user item lists (sequential models need these) ...
        self.user_history = self.make_user_history(data)

        # ... and the same as sets, for fast membership tests.
        self.user_history_set = [ set(history) for history in self.user_history ]

        # Needed for the PSP metrics.
        self.item_propensity = self.get_item_propensity()

    def get_item_count_map(self):
        """Count how many interactions in `self.data` involve each item id."""
        counts = defaultdict(int)
        for _user, item, _rating in self.data:
            counts[item] += 1
        return counts

    def get_item_propensity(self, A = 0.55, B = 1.5):
        """Return one weight per item, `1 + C * (frequency + B) ** -A`.

        `C` is `(log(number of interactions) - 1) * (B + 1) ** A`, so the
        weight decreases as an item's frequency grows.
        """
        counts = self.get_item_count_map()
        frequencies = np.array([ counts[item] for item in range(self.num_items) ])
        total = len(self.data)

        C = (np.log(total)-1)*np.power(B+1, A)
        weights = 1.0 + C*np.power(frequencies+B, -A)
        return np.ravel(weights)

class BaseTestDataset(CombinedBase):
    """Evaluation-side state shared by the data loaders: train and test histories per user."""
    def __init__(self, data, train_data, hyper_params, val_data):
        self.hyper_params = hyper_params
        self.batch_size = hyper_params['batch_size']
        self.implicit_task = hyper_params['task'] in [ 'implicit', 'sequential' ]
        self.data = data
        self.train_data = train_data
        self.num_users = hyper_params['total_users']
        self.num_items = hyper_params['total_items']

        # Per-user item lists (sequential models need these). When validation
        # data is supplied it is appended to each user's training history.
        self.train_user_history = self.make_user_history(train_data)
        if val_data is not None:
            self.val_user_history = self.make_user_history(val_data)
            for u in range(self.num_users):
                self.train_user_history[u].extend(self.val_user_history[u])
        self.test_user_history = self.make_user_history(data)

        # Set versions of the histories, for fast membership tests.
        self.train_user_history_set = [ set(history) for history in self.train_user_history ]
        self.test_user_history_set = [ set(history) for history in self.test_user_history ]

In [None]:
import torch
import numpy as np

# from data_loaders.base import BaseTrainDataset, BaseTestDataset
# from torch_utils import LongTensor, FloatTensor, is_cuda_available

class TrainDataset(BaseTrainDataset):
    """Training loader that yields mini-batches of (user, item, negatives) plus ratings.

    `data` is a sequence of `(user, item, rating)` triples. Negatives for
    every interaction are sampled by a background process so that iteration
    does not have to wait for them.
    """
    def __init__(self, data, hyper_params, track_events):
        super(TrainDataset, self).__init__(data, hyper_params)
        # Shuffling is disabled when `track_events` is set, so batches keep
        # the order of `data`.
        self.shuffle_allowed = not track_events

        # Copying ENTIRE dataset to GPU
        # (LongTensor / FloatTensor are the device-specific constructors
        # selected elsewhere in this file.)
        self.users_cpu = list(map(lambda x: x[0], data))
        self.users = LongTensor(self.users_cpu)
        self.items = LongTensor(list(map(lambda x: x[1], data)))
        self.ratings = FloatTensor(list(map(lambda x: x[2], data)))

        self.num_interactions = len(data)

        # One row of `num_train_negs` negatives per interaction; an item is a
        # valid negative when it is not in that interaction's user's history.
        self.init_background_sampler(
            lambda : torch.LongTensor(self.sample_negatives(
                len(self.data), self.hyper_params['num_train_negs'],
                lambda point, random_neg: random_neg not in self.user_history_set[self.users_cpu[point]]
            ))
        )

    def __iter__(self):
        """Yield `([users, items (extra trailing dim), negatives or None], ratings)` per batch."""
        # Important for optimal and stable performance
        indices = np.arange(self.num_interactions)
        if self.shuffle_allowed: np.random.shuffle(indices)
        temp_users = self.users[indices] ; temp_items = self.items[indices] ; temp_ratings = self.ratings[indices]

        if self.implicit_task: 
            # Take one pre-sampled set of negatives and reorder it like the
            # interactions; setting the event releases the background sampler
            # to produce more.
            negatives = self.result_queue.get()[indices]
            if is_cuda_available: negatives = negatives.cuda()
            self.event.set()

        for i in range(0, self.num_interactions, self.batch_size):
            yield [ 
                temp_users[i:i+self.batch_size], 
                temp_items[i:i+self.batch_size].unsqueeze(-1), 
                negatives[i:i+self.batch_size] if self.implicit_task else None, 
            ], temp_ratings[i:i+self.batch_size]

class TestDataset(BaseTestDataset):
    """Evaluation loader.

    For implicit/sequential tasks it iterates over users and yields each
    batch's test positives, sampled negatives and training history. For the
    explicit task it iterates over `(user, item, rating)` triples.
    """
    def __init__(self, data, train_data, hyper_params, val_data = None, test_set = False):
        super(TestDataset, self).__init__(data, train_data, hyper_params, val_data)
        self.test_set = test_set

        if self.implicit_task:
            # Padding for easier scattering
            # (with the default `pad`, shorter histories repeat their last item)
            self.test_user_history = LongTensor(self.pad(self.test_user_history))
            self.train_user_history = list(map(lambda x: LongTensor(x), self.train_user_history))

            # Copying all user-IDs to GPU
            self.all_users = LongTensor(list(range(self.num_users)))

            # Partial (sampled) evaluation is never used on the test set.
            self.partial_eval = (not test_set) and hyper_params['partial_eval']

            def one_sample():
                # `num_test_negs` negatives per user: items found in neither
                # the user's train nor test history.
                negatives = self.sample_negatives(
                    self.num_users, self.hyper_params['num_test_negs'],
                    lambda point, random_neg: random_neg not in self.train_user_history_set[point] and \
                                              random_neg not in self.test_user_history_set[point]
                )
                if self.partial_eval: negatives = torch.LongTensor(negatives) # Sampled ranking
                else: negatives = np.array(negatives) # Sampled AUC
                return negatives

            self.init_background_sampler(one_sample)

        else:
            self.users = LongTensor(list(map(lambda x: x[0], data)))
            self.items = LongTensor(list(map(lambda x: x[1], data)))
            self.ratings = FloatTensor(list(map(lambda x: x[2], data)))

        # Implicit tasks are evaluated per user, explicit ones per interaction.
        self.num_interactions = self.num_users if self.implicit_task else len(data)

    def __iter__(self):
        """Yield `(model input, targets)` per batch.

        Implicit: `[users, test positives or None, test negatives]` and
        `[train histories, test positive sets]`; the positives are only passed
        under partial evaluation. Explicit:
        `[users, items (extra trailing dim), None]` and the ratings.
        """
        if self.implicit_task:
            # Take one pre-sampled set of negatives, then release the
            # background sampler to produce more.
            negatives = self.result_queue.get() ; self.event.set()
            if self.partial_eval and is_cuda_available: negatives = negatives.cuda()

        for u in range(0, self.num_interactions, self.batch_size):
            if self.implicit_task:
                batch             = self.all_users[u:u+self.batch_size]
                train_positive    = self.train_user_history[u:u+self.batch_size]
                test_positive     = self.test_user_history[u:u+self.batch_size]
                test_positive_set = self.test_user_history_set[u:u+self.batch_size]
                test_negative     = negatives[u:u+self.batch_size]

                yield [ batch, test_positive if self.partial_eval else None, test_negative ], [ 
                    train_positive,
                    test_positive_set,
                ]
            else:
                yield [ 
                    self.users[u:u+self.batch_size], 
                    self.items[u:u+self.batch_size].unsqueeze(-1), 
                    None, 
                ], self.ratings[u:u+self.batch_size]

### Data Load

In [None]:
import h5py
import numpy as np

# from utils import get_data_loader_class
# from data_path_constants import get_data_path, get_index_path

def load_data(hyper_params, track_events = False):
    """Load a dataset and build its train, test and validation loaders.

    `hyper_params` is updated in place with 'total_users', 'total_items' and
    'partial_eval' before the loaders are created.

    Returns:
        `(train_loader, test_loader, val_loader, hyper_params)`.
    """
    ratings_dir = get_data_path(hyper_params)
    split_dir = get_index_path(hyper_params)

    holder = DataHolder(ratings_dir, split_dir)
    print("# of users: {}\n# of items: {}".format(holder.num_users, holder.num_items))

    hyper_params['total_users']  = holder.num_users
    hyper_params['total_items']  = holder.num_items
    # With more than 1000 items, evaluate on a sampled part of the item
    # space (validation set only).
    hyper_params['partial_eval'] = hyper_params['total_items'] > 1_000

    train_cls, test_cls = get_data_loader_class(hyper_params)

    # Only these models get the validation interactions passed to the test loader.
    needs_val_history = hyper_params['model_type'] in [ 'SASRec', 'SVAE', 'MVAE' ]

    train_loader = train_cls(holder.train, hyper_params, track_events)
    test_loader = test_cls(
        holder.test, holder.train, hyper_params, test_set = True,
        val_data = holder.val if needs_val_history else None
    )
    val_loader = test_cls(holder.val, holder.train, hyper_params)

    return train_loader, test_loader, val_loader, hyper_params

class DataHolder:
    """Holds the rating triples of a dataset together with their split labels.

    `index` carries one label per interaction: -1 marks a discarded
    interaction; 0, 1 and 2 select the train, validation and test splits.
    """
    def __init__(self, rating_data_path, index_path):
        ratings_file = rating_data_path + "total_data.hdf5"
        with h5py.File(ratings_file, 'r') as handle:
            self.data = list(zip(handle['user'][:], handle['item'][:], handle['rating'][:]))

        index_archive = np.load(index_path + "index.npz")
        self.index = index_archive['data']
        self.remap()

    def remap(self):
        """Drop discarded interactions and renumber the remaining users and items from 0."""
        # First pass: collect the users/items that survive.
        kept_users, kept_items = set(), set()
        for position, (user, item, _rating) in enumerate(self.data):
            if self.index[position] == -1:
                continue
            kept_users.add(user)
            kept_items.add(item)

        user_map = { user: new_id for new_id, user in enumerate(kept_users) }
        item_map = { item: new_id for new_id, item in enumerate(kept_items) }

        # Second pass: rewrite the surviving rows with the new ids.
        remapped_rows, remapped_index = [], []
        for position, (user, item, rating) in enumerate(self.data):
            label = self.index[position]
            if label != -1:
                remapped_rows.append([ user_map[user], item_map[item], rating ])
                remapped_index.append(label)

        self.data = remapped_rows
        self.index = remapped_index
        self.num_users = len(kept_users)
        self.num_items = len(kept_items)

    def select(self, index_val):
        """Return the rows whose split label equals `index_val`."""
        return [
            row for position, row in enumerate(self.data)
            if self.index[position] == index_val
        ]

    @property
    def train(self):
        return self.select(0)

    @property
    def val(self):
        return self.select(1)

    @property
    def test(self):
        return self.select(2)

    @property
    def num_train_interactions(self):
        return int(sum(1 for label in self.index if label == 0))

    @property
    def num_val_interactions(self):
        return int(sum(1 for label in self.index if label == 1))

    @property
    def num_test_interactions(self):
        return int(sum(1 for label in self.index if label == 2))

### Evaluate

In [None]:
import torch
import numpy as np
from numba import jit, float32, float64, int64

# from utils import INF

def evaluate(model, criterion, reader, hyper_params, item_propensity, topk = [ 10, 100 ], test = False):
    """Evaluate `model` on every batch of `reader` and return a metrics dict.

    For the explicit task the only metric is 'MSE'. Otherwise the dict holds
    'HR@k', 'NDCG@k' and 'PSP@k' for each `k` in `topk`, plus 'AUC'. All
    values are averaged over `reader.num_interactions` and rounded to four
    decimals; the ranking metrics are scaled by 100. `metrics['eval']` is set
    to 'partial' when sampled item-space evaluation is used.
    """
    metrics = {}

    # Sampled item-space evaluation applies only off the test set, when the
    # hyper-parameters ask for it, and not for these model types.
    partial_eval = (
        (not test)
        and hyper_params['partial_eval']
        and (hyper_params['model_type'] not in [ 'MVAE', 'SVAE', 'pop_rec' ])
    )
    if partial_eval:
        metrics['eval'] = 'partial'

    explicit_task = hyper_params['task'] == 'explicit'
    if explicit_task:
        metrics['MSE'] = 0.0
    else:
        preds, y_binary = [], []
        ranking_keys = [
            '{}@{}'.format(kind, k)
            for kind in [ 'HR', 'NDCG', 'PSP' ]
            for k in topk
        ]
        for key in ranking_keys:
            metrics[key] = 0.0

    model.eval()
    with torch.no_grad():
        for data, y in reader:
            output = model(data, eval = True)
            if hyper_params['model_type'] in [ 'MVAE', 'SVAE' ]:
                output, _, _ = output
            if hyper_params['model_type'] == 'SVAE':
                output = output[:, -1, :]

            if explicit_task:
                batch_errors = criterion(output, y, return_mean = False)
                metrics['MSE'] += torch.sum(batch_errors.data)
                continue

            batch_fn = evaluate_batch_partial if partial_eval else evaluate_batch
            metrics, batch_preds, batch_labels = batch_fn(data, output, y, item_propensity, topk, metrics)
            preds += batch_preds
            y_binary += batch_labels

    if explicit_task:
        metrics['MSE'] = round(float(metrics['MSE']) / reader.num_interactions, 4)
        return metrics

    # NOTE: sklearn's `roc_auc_score` is very slow, hence `fast_auc`
    metrics['AUC'] = round(fast_auc(np.array(y_binary), np.array(preds)), 4)

    for key in ranking_keys:
        metrics[key] = round(float(100.0 * metrics[key]) / reader.num_interactions, 4)

    return metrics

def evaluate_batch(data, output_batch, y, item_propensity, topk, metrics):
    """Accumulate full item-space ranking metrics for one batch.

    `y` unpacks to `(train_positive, test_positive_set)` and the third element
    of `data` holds the negatives used for AUC. `output_batch` is modified in
    place: train-set items get a score of `-INF` before the top-k is taken.

    Returns the updated `metrics` dict together with the flat lists of scores
    and 0/1 labels collected for the AUC computation.
    """
    train_positive, test_positive_set = y
    _, _, auc_negatives = data
    batch_size = len(output_batch)

    # ---- AUC inputs: scores of held-out positives and of the sampled negatives
    temp_preds, temp_y = [], []
    logits_cpu = output_batch.cpu().numpy()
    for b in range(batch_size):
        positives = test_positive_set[b]
        # Validation set could have 0 positive interactions
        if len(positives) == 0: continue
        negatives = auc_negatives[b]

        temp_preds.extend(np.take(logits_cpu[b], np.array(list(positives))).tolist())
        temp_y.extend([ 1.0 ] * len(positives))

        temp_preds.extend(np.take(logits_cpu[b], negatives).tolist())
        temp_y.extend([ 0.0 ] * len(negatives))

    # ---- Already-consumed (train) items must never be recommended
    for b in range(batch_size): output_batch[b][ train_positive[b] ] = -INF

    num_ranked = min(item_propensity.shape[0], max(topk))
    _, indices = torch.topk(output_batch, num_ranked, sorted = True)
    indices = indices.cpu().numpy().tolist()

    # ---- HR / NDCG / PSP, one row at a time
    for b in range(batch_size):
        positives = test_positive_set[b]
        num_pos = float(len(positives))
        # Validation set could have 0 positive interactions after sampling
        if num_pos == 0: continue

        # Propensities of the held-out items, largest first (ideal PSP ordering)
        test_positive_sorted_psp = sorted([ item_propensity[x] for x in positives ])
        test_positive_sorted_psp.reverse()

        for k in topk:
            best_possible_hits = float(min(num_pos, k))
            ranked = indices[b][:k]

            metrics['HR@{}'.format(k)] += float(len(set(ranked) & positives)) / best_possible_hits

            dcg, idcg, psp, max_psp = 0.0, 0.0, 0.0, 0.0
            for at, pred in enumerate(ranked):
                if pred in positives:
                    dcg += 1.0 / np.log2(at + 2)
                    psp += float(item_propensity[pred]) / best_possible_hits
                if at < num_pos:
                    idcg += 1.0 / np.log2(at + 2)
                    max_psp += test_positive_sorted_psp[at]

            metrics['NDCG@{}'.format(k)] += dcg / idcg
            metrics['PSP@{}'.format(k)] += psp / max_psp

    return metrics, temp_preds, temp_y

def evaluate_batch_partial(data, output, y, item_propensity, topk, metrics):
    """Accumulate negative-sampled ranking metrics for one batch.

    Moves the batch to numpy, delegates the per-row work to
    `evaluate_batch_partial_jit`, and adds the returned per-cutoff sums to
    `metrics`. Returns `(metrics, preds, labels)` with the last two as lists.
    """
    _, test_pos_items, _ = data
    test_pos_items = test_pos_items.cpu().numpy()

    # `output` is the pair (scores of positives, scores of sampled negatives)
    pos_score, neg_score = ( scores.cpu().numpy() for scores in output )

    temp_preds, temp_y, hr, ndcg, psp = evaluate_batch_partial_jit(
        pos_score, neg_score, test_pos_items, np.array(item_propensity), np.array(topk)
    )

    for k, hr_k, ndcg_k, psp_k in zip(topk, hr, ndcg, psp):
        metrics['HR@{}'.format(k)] += hr_k
        metrics['NDCG@{}'.format(k)] += ndcg_k
        metrics['PSP@{}'.format(k)] += psp_k

    return metrics, temp_preds.tolist(), temp_y.tolist()

# Numba-compiled core of the negative-sampled ("partial") evaluation.
#
# Per the declared signature: `pos_score` / `neg_score` are 2-D float32 arrays,
# `test_pos_items` is a 2-D int64 array, `item_propensity` a 1-D float64 array and
# `topk` a 1-D int64 array. Returns five 1-D float32 arrays:
# (AUC scores, AUC labels, HR sums, NDCG sums, PSP sums), the last three holding
# one entry per cut-off in `topk`, summed over the rows of the batch.
@jit('Tuple((float32[:], float32[:], float32[:], float32[:], float32[:]))(float32[:,:], float32[:,:], int64[:,:], float64[:], int64[:])')
def evaluate_batch_partial_jit(pos_score, neg_score, test_pos_items, item_propensity, topk):
    # Buffers sized for the worst case (no padding stripped); trimmed to `at_preds` on return
    temp_preds = np.zeros(
        ((pos_score.shape[0] * pos_score.shape[1]) + (neg_score.shape[0] * neg_score.shape[1])), 
        dtype = np.float32
    )
    temp_y = np.zeros(temp_preds.shape, dtype = np.float32)
    at_preds = 0

    hr_arr = np.zeros((len(topk)), dtype = np.float32)
    ndcg_arr = np.zeros((len(topk)), dtype = np.float32)
    psp_arr = np.zeros((len(topk)), dtype = np.float32)

    for b in range(len(pos_score)):
        pos, neg = pos_score[b, :], neg_score[b, :]

        # pos will be padded, un-pad it
        # Padding is detected as a trailing run of identical scores.
        # NOTE(review): a genuine tie between the last positive scores would be stripped too.
        last_index = len(pos) - 1
        while last_index > 0 and pos[last_index] == pos[last_index - 1]: last_index -= 1
        pos = pos[:last_index + 1]

        # Add to AUC
        temp_preds[at_preds:at_preds+len(pos)] = pos
        temp_y[at_preds:at_preds+len(pos)] = 1
        at_preds += len(pos)

        temp_preds[at_preds:at_preds+len(neg)] = neg
        temp_y[at_preds:at_preds+len(neg)] = 0
        at_preds += len(neg)

        # get rank of all elements in pos
        # (indices of the concatenated [pos, neg] scores, highest score first)
        temp_ranks = np.argsort(- np.concatenate((pos, neg)))

        # To maintain order
        # pos_ranks[r] = 1-based rank of the r-th positive among positives + negatives
        pos_ranks = np.zeros(len(pos))
        for at, r in enumerate(temp_ranks):
            if r < len(pos): pos_ranks[r] = at + 1

        # Propensities of every entry in this row of `test_pos_items`, largest first.
        # NOTE(review): this uses the full (still padded) row, unlike `pos` above - confirm intended.
        test_positive_sorted_psp = sorted([ item_propensity[x] for x in test_pos_items[b] ])[::-1]

        for at_k, k in enumerate(topk): 
            num_pos = float(len(pos))

            # Hits within the top-k, normalized by the best achievable number of hits
            hr_arr[at_k] += np.sum(pos_ranks <= k) / float(min(num_pos, k))

            dcg, idcg, psp, max_psp = 0.0, 0.0, 0.0, 0.0
            for at, rank in enumerate(pos_ranks):
                if rank <= k:
                    dcg += 1.0 / np.log2(rank + 1) # 1-based indexing
                    psp += item_propensity[test_pos_items[b][at]] / float(min(num_pos, k))
                # The ideal terms are summed over every positive, independent of `k`
                idcg += 1.0 / np.log2(at + 2)
                max_psp += test_positive_sorted_psp[at]

            ndcg_arr[at_k] += dcg / idcg
            psp_arr[at_k] += psp / max_psp

    return temp_preds[:at_preds], temp_y[:at_preds], hr_arr, ndcg_arr, psp_arr

# ROC-AUC computed in one pass over the labels sorted by predicted score.
# Compiled by numba for two 1-D float64 arrays; returns a float64.
# Assumes `y_true` holds 0.0 / 1.0 labels, as produced by the evaluation helpers in this file.
@jit(float64(float64[:], float64[:]))
def fast_auc(y_true, y_prob):
    # Labels re-ordered by ascending score
    y_true = y_true[np.argsort(y_prob)]
    # nfalse: negatives seen so far; auc: number of (negative, positive) pairs with the negative ranked lower
    nfalse, auc = 0, 0
    for i in range(len(y_true)):
        nfalse += (1 - y_true[i])
        auc += y_true[i] * nfalse
    # Normalize by (#negatives * #positives).
    # NOTE(review): the denominator is zero when only one class is present - callers must avoid that.
    return auc / (nfalse * (len(y_true) - nfalse))

### Loss

In [None]:
import torch
import torch.nn.functional as F

# from torch_utils import is_cuda_available

class CustomLoss(torch.nn.Module):
    """Loss module whose `forward` is selected from the hyper-parameters.

    `hyper_params['task']` picks the base loss (`explicit` -> MSE,
    `implicit` / `sequential` -> BPR); `hyper_params['model_type']` overrides it
    for `MVAE`, `SVAE` and `SASRec`. Every loss accepts
    `(output, target, return_mean = True)`.
    """

    # Upper bound of the KL-divergence weight used by the VAE losses
    ANNEAL_CAP = 0.2

    def __init__(self, hyper_params):
        super(CustomLoss, self).__init__()
        self.forward = {
            'explicit': self.mse,
            'implicit': self.bpr,
            'sequential': self.bpr,
        }[hyper_params['task']]

        if hyper_params['model_type'] == "MVAE": self.forward = self.vae_loss
        if hyper_params['model_type'] == "SVAE": self.forward = self.svae_loss
        if hyper_params['model_type'] == "SASRec": self.forward = self.bce_sasrec

        self.torch_bce = torch.nn.BCEWithLogitsLoss()
        self.anneal_val = 0.0
        self.hyper_params = hyper_params

    def mse(self, output, y, return_mean = True):
        """Squared error between `output` and `y`; element-wise unless `return_mean`."""
        mse = torch.pow(output - y, 2)

        if return_mean: return torch.mean(mse)
        return mse

    def bce_sasrec(self, output, pos, return_mean = True):
        """Binary cross-entropy on (positive, negative) logits, skipping padded positions.

        Positions where `pos` equals `hyper_params['total_items']` are treated
        as padding. `return_mean` is accepted for interface parity only; the
        result is always the sum of two mean-reduced BCE terms.
        """
        pos_logits, neg_logits = output
        # Labels are created next to the logits (same device / dtype), so no
        # global CUDA flag is needed and CPU models work on CUDA machines too.
        pos_labels = torch.ones_like(pos_logits)
        neg_labels = torch.zeros_like(neg_logits)

        indices = pos != self.hyper_params['total_items']

        loss = self.torch_bce(pos_logits[indices], pos_labels[indices])
        loss += self.torch_bce(neg_logits[indices], neg_labels[indices])
        return loss

    def bpr(self, output, y, return_mean = True):
        """Bayesian Personalized Ranking loss on a (positive, negatives) score pair.

        The positive score is repeated once per negative column and both are
        flattened before `-logsigmoid(pos - neg)` is taken. `y` is unused.
        """
        pos_output, neg_output = output
        pos_output = pos_output.repeat(1, neg_output.shape[1]).view(-1)
        neg_output = neg_output.view(-1)

        loss = -F.logsigmoid(pos_output - neg_output)

        if return_mean: return torch.mean(loss)
        return loss

    def anneal(self, step_size):
        """Increase the KL weight by `step_size`, saturating at `ANNEAL_CAP`.

        The previous `max(...)` made the weight jump to the cap on the first
        call and then keep growing; a cap requires `min(...)`.
        """
        self.anneal_val = min(self.anneal_val + step_size, self.ANNEAL_CAP)

    def vae_loss(self, output, y_true_s, return_mean = True):
        """Annealed-KL + multinomial likelihood loss for `(decoder_output, mu, logvar)`."""
        decoder_output, mu_q, logvar_q = output

        # Calculate KL Divergence loss
        kld = torch.mean(torch.sum(0.5 * (-logvar_q + torch.exp(logvar_q) + mu_q**2 - 1), -1))

        # Calculate Likelihood
        decoder_output = F.log_softmax(decoder_output, -1)
        likelihood = torch.sum(-1.0 * y_true_s * decoder_output, -1)

        final = (self.anneal_val * kld) + (likelihood)

        if return_mean: return torch.mean(final)
        return final

    def svae_loss(self, output, y, return_mean = True):
        """Sequential VAE loss; `y` is `(y_true_s, y_indices)`.

        Positions whose index equals `hyper_params['total_items']` are padding
        and are excluded from the likelihood, which is normalized by
        `batch_size * hyper_params['num_next']`.
        """
        decoder_output, mu_q, logvar_q = output
        dec_shape = decoder_output.shape # [batch_size x seq_len x total_items]

        # Calculate KL Divergence loss
        kld = torch.mean(torch.sum(0.5 * (-logvar_q + torch.exp(logvar_q) + mu_q**2 - 1), -1))

        # Don't compute loss on padded items
        y_true_s, y_indices = y
        keep_indices = y_indices != self.hyper_params['total_items']
        y_true_s = y_true_s[keep_indices]
        decoder_output = decoder_output[keep_indices]

        # Calculate Likelihood
        decoder_output = F.log_softmax(decoder_output, -1)
        likelihood = torch.sum(-1.0 * y_true_s * decoder_output)
        likelihood = likelihood / float(dec_shape[0] * self.hyper_params['num_next'])

        final = (self.anneal_val * kld) + (likelihood)

        if return_mean: return torch.mean(final)
        return final

### Pytorch

In [None]:
def train(model, criterion, optimizer, reader, hyper_params, forgetting_events, track_events):
    """Train `model` for one epoch over `reader`.

    When `track_events` is true, `forgetting_events` is updated in place, batch
    window by batch window: with the per-interaction loss for the explicit
    task, otherwise with the number of negatives scored above the positive.

    Returns `(metrics, forgetting_events)`; `metrics` is always an empty dict.
    """
    import torch

    model.train()

    # Nothing is measured on the train set; kept for the caller's tuple-unpacking
    metrics = {}

    # Offset of the current batch inside `forgetting_events`
    offset = 0

    for data, y in tqdm(reader):
        # Start every step from clean gradients
        model.zero_grad()
        optimizer.zero_grad()

        output = model(data)

        # Per-interaction loss; reduced to a scalar only right before backprop
        loss = criterion(output, y, return_mean = False)
        criterion.anneal(1.0 / float(len(reader) * hyper_params['epochs']))

        if track_events:
            with torch.no_grad():
                batch_len = data[0].shape[0]
                window = slice(offset, offset + batch_len)

                if hyper_params['task'] == 'explicit':
                    forgetting_events[window] += loss.data
                else:
                    pos_output, neg_output = output
                    pos_output = pos_output.repeat(1, neg_output.shape[1])
                    num_incorrect = torch.sum((neg_output > pos_output).float(), -1)
                    forgetting_events[window] += num_incorrect.data

                offset += batch_len

        torch.mean(loss).backward()
        optimizer.step()

    return metrics, forgetting_events

In [None]:
def train_complete(hyper_params, train_reader, val_reader, model, model_class, track_events):
    """Train `model` with periodic validation, checkpointing and early stopping.

    A checkpoint is written to `hyper_params['model_path']` whenever the
    validation MSE (explicit task) or AUC / HR@10 (other tasks) improves. After
    training, the checkpoint - if one exists - is loaded into a fresh
    `model_class(hyper_params)` instance.

    Returns `(model, forgetting_events)`; the second item is `None` unless
    `track_events` is true, in which case it is a numpy array divided by
    `hyper_params['epochs']`.
    """
    import torch

    criterion = CustomLoss(hyper_params)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr = hyper_params['lr'],
        betas = (0.9, 0.98),
        weight_decay = hyper_params['weight_decay'],
    )

    file_write(hyper_params['log_file'], str(model))
    file_write(hyper_params['log_file'], "\nModel Built!\nStarting Training...\n")

    def save_checkpoint():
        print("Saving model...")
        torch.save(model.state_dict(), hyper_params['model_path'])

    try:
        best_MSE, best_AUC, best_HR = float(INF), -float(INF), -float(INF)
        # Number of consecutive validations without an improvement
        decreasing_streak = 0

        forgetting_events = None
        if track_events:
            forgetting_events = torch.zeros(train_reader.num_interactions).float()
            if is_cuda_available: forgetting_events = forgetting_events.cuda()

        for epoch in range(1, hyper_params['epochs'] + 1):
            epoch_start_time = time.time()

            # One pass over the training data
            metrics, _ = train(
                model, criterion, optimizer, train_reader, hyper_params,
                forgetting_events, track_events
            )

            # Calculating the metrics on the validation set
            if (epoch % hyper_params['validate_every'] == 0) or (epoch == 1):
                metrics = evaluate(model, criterion, val_reader, hyper_params, train_reader.item_propensity)
                metrics['dataset'] = hyper_params['dataset']
                decreasing_streak += 1

                # Keep the best model seen so far on the validation set
                if hyper_params['task'] == 'explicit':
                    if metrics['MSE'] < best_MSE:
                        save_checkpoint()
                        decreasing_streak, best_MSE = 0, metrics['MSE']
                elif metrics['AUC'] > best_AUC:
                    save_checkpoint()
                    decreasing_streak, best_AUC = 0, metrics['AUC']
                elif metrics['HR@10'] > best_HR:
                    save_checkpoint()
                    decreasing_streak, best_HR = 0, metrics['HR@10']

            log_end_epoch(hyper_params, metrics, epoch, time.time() - epoch_start_time, metrics_on = '(VAL)')

            # Stop once validation has failed to improve often enough
            if 'early_stop' in hyper_params and decreasing_streak >= hyper_params['early_stop']:
                file_write(hyper_params['log_file'], "Early stopping..")
                break

    except KeyboardInterrupt: print('Exiting from training early')

    # Load best model and return it for evaluation on test-set
    if os.path.exists(hyper_params['model_path']):
        model = model_class(hyper_params)
        if is_cuda_available: model = model.cuda()
        model.load_state_dict(torch.load(hyper_params['model_path']))

    model.eval()

    if track_events:
        forgetting_events = forgetting_events.cpu().numpy() / float(hyper_params['epochs'])

    return model, forgetting_events

In [None]:
import os
import time
import importlib
import datetime as dt
from tqdm import tqdm

# from utils import file_write, log_end_epoch, INF, valid_hyper_params
# from data_path_constants import get_log_file_path, get_model_file_path


def main_pytorch(hyper_params, track_events = False, eval_full = True):
    """Load the data, train the configured model and optionally evaluate on the test set.

    Args:
        hyper_params: configuration dict; it is replaced by the dict returned
            from `load_data` before training.
        track_events: forwarded to data loading and training to collect
            per-interaction forgetting statistics.
        eval_full: when true, the trained model is evaluated on the test reader.

    Returns:
        `(metrics, forgetting_events)`; `metrics` is empty when `eval_full` is
        false. NOTE: returns `None` (not a tuple) when `valid_hyper_params`
        rejects the configuration.
    """
    # from load_data import load_data
    # from eval import evaluate

    # from torch_utils import is_cuda_available, xavier_init, get_model_class
    # from loss import CustomLoss

    if not valid_hyper_params(hyper_params): 
        print("Invalid task combination specified, exiting.")
        return

    # Load the data readers
    train_reader, test_reader, val_reader, hyper_params = load_data(hyper_params, track_events = track_events)
    file_write(hyper_params['log_file'], "\n\nSimulation run on: " + str(dt.datetime.now()) + "\n\n")
    file_write(hyper_params['log_file'], "Data reading complete!")
    file_write(hyper_params['log_file'], "Number of train batches: {:4d}".format(len(train_reader)))
    file_write(hyper_params['log_file'], "Number of validation batches: {:4d}".format(len(val_reader)))
    file_write(hyper_params['log_file'], "Number of test batches: {:4d}".format(len(test_reader)))

    # Initialize & train the model
    start_time = time.time()

    if hyper_params['model_type'] == 'NeuMF': 
        model, forgetting_events = train_neumf(hyper_params, train_reader, val_reader, track_events)
    else:
        model = get_model_class(hyper_params)(hyper_params)
        if is_cuda_available: model = model.cuda()
        xavier_init(model)
        model, forgetting_events = train_complete(
            hyper_params, train_reader, val_reader, model, get_model_class(hyper_params), track_events
        )

    metrics = {}
    if eval_full:
        # Calculating MSE on test-set
        criterion = CustomLoss(hyper_params)
        metrics = evaluate(model, criterion, test_reader, hyper_params, train_reader.item_propensity, test = True)
        log_end_epoch(hyper_params, metrics, 'final', time.time() - start_time, metrics_on = '(TEST)')

    # We have no space left for storing the models.
    # The checkpoint only exists if validation improved at least once, so guard
    # the removal (same check `train_complete` does before loading it).
    model_path = hyper_params['model_path']
    if os.path.exists(model_path): os.remove(model_path)
    del model, train_reader, test_reader, val_reader
    return metrics, forgetting_events

### SVPHandler

In [None]:
import numpy as np
from collections import defaultdict

# from main import main_pytorch

class SVPHandler:
    """Selection-via-proxy: train a small proxy model and use its forgetting
    statistics to decide which training interactions (or users) to drop.

    Construction trains the proxy via `main_pytorch` and stores the returned
    per-train-interaction scores in `self.forgetted_count`. Every sampling
    method takes the split-flag list `index` (0 marks a train interaction),
    sets the flags of the dropped train interactions to -1 in place and
    returns the same list.
    """

    def __init__(self, model_type, loss_type, hyper_params):
        # NOTE: the caller's `hyper_params` dict is modified in place
        hyper_params['model_type'] = model_type
        hyper_params['task'] = loss_type
        hyper_params['num_train_negs'] = 1
        hyper_params['num_test_negs'] = 100

        hyper_params['latent_size'] = 10
        hyper_params['dropout'] = 0.3
        hyper_params['weight_decay'] = float(1e-6)
        hyper_params['lr'] = 0.006
        hyper_params['epochs'] = 50
        hyper_params['validate_every'] = 5000
        hyper_params['batch_size'] = 1024
        self.hyper_params = hyper_params
        self.hyper_params['log_file'] = self.log_file
        self.hyper_params['model_path'] = self.model_file

        self.train_model()

    def train_model(self): 
        """Train the proxy model and keep its forgetting statistics."""
        _, self.forgetted_count = main_pytorch(self.hyper_params, track_events = True, eval_full = False)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _drop_lowest_scored(scores, percent, index):
        """Flag as removed the `percent`% train interactions with the lowest `scores`.

        `scores[j]` belongs to the j-th train interaction, i.e. the j-th entry
        of `index` that equals 0.
        """
        index_map = [ at for at, flag in enumerate(index) if flag == 0 ]
        split_point = int(float(len(scores)) * (float(percent) / 100.0))
        for train_at in np.argsort(scores)[:split_point]:
            index[index_map[train_at]] = -1 # Remove
        return index

    @staticmethod
    def _train_interactions(data, index):
        """Return `(index_map, user_map, hist)` for the train interactions.

        `index_map[j]` / `user_map[j]` are the flat position and the user of
        the j-th train interaction; `hist[u]` is user `u`'s train count.
        """
        index_map, user_map, hist, at = [], [], defaultdict(int), 0
        for u in range(len(data)):
            for _ in data[u]:
                if index[at] == 0:
                    index_map.append(at)
                    user_map.append(u)
                    hist[u] += 1
                at += 1
        return index_map, user_map, hist

    @staticmethod
    def _drop_lowest_scored_users(user_scores, percent, index, index_map, user_map, hist):
        """Remove whole users, lowest mean score first, until at least
        `percent`% of the train interactions have been removed."""
        ranked_users = sorted(user_scores.items(), key = lambda x: np.mean(x[1]))

        interactions_to_remove = len(index_map) * (float(percent) / 100.0)
        removed, users_to_remove = 0, set()
        for u, _ in ranked_users:
            if removed >= interactions_to_remove: break
            users_to_remove.add(u)
            removed += hist[u]

        for train_at, u in enumerate(user_map):
            if u in users_to_remove: index[index_map[train_at]] = -1

        return index

    # ----------------------------------------------------------------- samplers

    def forgetting_events(self, percent, data, index):
        """Drop the `percent`% train interactions with the lowest forgetting count."""
        # Keep those points which have the maximum forgetted count
        # => Remove those points which have the minimum forgetted count
        return self._drop_lowest_scored(self.forgetted_count, percent, index)

    def forgetting_events_user(self, percent, data, index):
        """Drop the users with the lowest mean forgetting count."""
        # Keep those users which have the maximum forgetted count
        # Remove those users which have the minimum forgetted count
        index_map, user_map, hist = self._train_interactions(data, index)

        user_scores = defaultdict(list)
        for train_at, cnt in enumerate(self.forgetted_count):
            user_scores[user_map[train_at]].append(cnt)

        return self._drop_lowest_scored_users(user_scores, percent, index, index_map, user_map, hist)

    def compute_freq(self, data, index, freq_type):
        """Count train interactions per user (`freq_type == 0`) or item (`freq_type == 1`).

        Returns `(counts, key_to_position)` where `counts[key_to_position[k]]`
        is the count of key `k`.
        """
        freq, at = defaultdict(int), 0
        for u in range(len(data)):
            for i, r, t in data[u]:
                if index[at] == 0:
                    freq[[ u, i ][freq_type]] += 1
                at += 1

        keys = list(freq.keys())
        return [ freq[k] for k in keys ], dict(zip(keys, range(len(keys))))

    def compute_prop(self, freq_vector, num_instances, A = 0.55, B = 1.5):
        """Return `1 + C * (freq + B) ** -A` with `C = (log(num_instances) - 1) * (B + 1) ** A`."""
        C = (np.log(num_instances)-1)*np.power(B+1, A)
        wts = 1.0 + C*np.power(np.array(freq_vector)+B, -A)
        return np.ravel(wts)

    def forgetting_events_propensity(self, percent, data, index, pooling_method = 'max'):
        """Like `forgetting_events`, with each count re-weighted by the pooled
        user / item (inverse) propensity of its interaction.

        Raises:
            ValueError: if `pooling_method` is neither `'sum'` nor `'max'`.
        """
        # Keep those points which have the maximum forgetted count
        # Remove those points which have the minimum forgetted count
        if pooling_method not in ('sum', 'max'):
            # Previously an unknown method silently yielded `None` propensities
            raise ValueError("Unsupported pooling_method: {!r}".format(pooling_method))

        num_interactions = len(self.forgetted_count)
        user_freq, user_map = self.compute_freq(data, index, 0)
        user_propensity_vector = self.compute_prop(user_freq, num_interactions)
        item_freq, item_map = self.compute_freq(data, index, 1)
        item_propensity_vector = self.compute_prop(item_freq, num_interactions)

        interaction_propensity, at = [], 0
        for u in range(len(data)):
            for i, r, t in data[u]:
                if index[at] == 0:
                    prop_u = user_propensity_vector[user_map[u]]
                    prop_i = item_propensity_vector[item_map[i]]
                    interaction_propensity.append(
                        prop_u + prop_i if pooling_method == 'sum' else max(prop_u, prop_i)
                    )
                at += 1
        assert len(interaction_propensity) == num_interactions

        # interaction_propensity actually estimates the `inverse` propensity, hence multiply
        updated_count = np.array(self.forgetted_count) * np.array(interaction_propensity)

        return self._drop_lowest_scored(updated_count, percent, index)

    def forgetting_events_user_propensity(self, percent, data, index):
        """Like `forgetting_events_user`, with counts re-weighted by the user's
        (inverse) propensity."""
        # Keep those users which have the maximum forgetted count
        # Keep those users which have the maximum propensity --> minimum frequency
        # Remove those users which have the minimum forgetted count
        num_interactions = len(self.forgetted_count)
        user_freq, user_index_map = self.compute_freq(data, index, 0)
        user_propensity_vector = self.compute_prop(user_freq, num_interactions)

        index_map, user_map, hist = self._train_interactions(data, index)

        user_scores = defaultdict(list)
        for train_at, cnt in enumerate(self.forgetted_count):
            u = user_map[train_at]
            user_scores[u].append(cnt * user_propensity_vector[user_index_map[u]])

        return self._drop_lowest_scored_users(user_scores, percent, index, index_map, user_map, hist)

    @property
    def model_file(self): 
        return get_model_file_path(self.hyper_params)

    @property
    def log_file(self): 
        return get_log_file_path(self.hyper_params)

### RatingData

In [None]:
import os
import h5py
import json
import math
import random
import numpy as np
from collections import defaultdict

import networkx as nx
import networkit as nk
nk.setNumberOfThreads(16)

# from graph_sampling.ForestFire import ForestFireSampler
# from graph_sampling.RW import RandomWalkWithRestartSampler

class rating_data:
    def __init__(self, data):
        self.data = data

        self.index = [] # 0: train, 1: validation, 2: test, -1: removed/ignore
        for user_data in self.data:
            for _ in range(len(user_data)): self.index.append(42)

        self.complete_data_stats = None

    def train_test_split(self, split_type):
        at = 0

        for user in range(len(self.data)):
            if split_type == "20_percent_hist": 
                first_split_point = int(0.8 * len(self.data[user]))
                second_split_point = int(0.9 * len(self.data[user]))

                indices = np.arange(len(self.data[user]))
                np.random.shuffle(indices)

                for timestep, (item, rating, time) in enumerate(self.data[user]):
                    if len(self.data[user]) < 3: self.index[at] = -1
                    else:
                        # Force atleast one element in user history to be in test
                        if timestep == indices[0]: self.index[at] = 2
                        else:
                            if timestep in indices[:first_split_point]: self.index[at] = 0
                            elif timestep in indices[first_split_point:second_split_point]: self.index[at] = 1
                            else: self.index[at] = 2
                    at += 1
            
            elif split_type == "leave_2":
                for timestep, (item, rating, time) in enumerate(self.data[user]):
                    if len(self.data[user]) < 3: self.index[at] = -1
                    else:
                        if timestep <= len(self.data[user]) - 3: self.index[at] = 0
                        elif timestep == len(self.data[user]) - 2: self.index[at] = 1
                        else: self.index[at] = 2
                    at += 1

        assert at == len(self.index)
        self.complete_data_stats = None

    def interaction_random_sample(self, percent):
        active, at = set(), 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                # NOTE: only sample on the train-set
                if self.index[at] == 0: active.add(at)
                at += 1
        active = list(active)

        # Remove `percent`% at random
        remove_mask = {}
        for i in active: remove_mask[i] = False
        random.shuffle(active)
        split_point = int(float(len(active)) * (float(percent) / 100.0))
        for i in active[:split_point]: remove_mask[i] = True

        at = 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                if remove_mask.get(at, False) and self.index[at] == 0: self.index[at] = -1
                at += 1
        assert at == len(self.index)

    def frequency_sample(self, percent, sample_type):
        hist, at = {}, 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                key = [ u, i ][sample_type]
                if key not in hist: hist[key] = []
                # NOTE: only sample on the train-set
                if self.index[at] == 0: hist[key].append(at) 
                at += 1

        # Remove `percent`% at random
        remove_mask = {}
        for key in hist:
            interactions = hist[key]
            random.shuffle(interactions)
            split_point = math.ceil(float(len(interactions)) * (float(percent) / 100.0))
            for i in interactions[:split_point]: remove_mask[i] = True

        at = 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                if remove_mask.get(at, False) and self.index[at] == 0: self.index[at] = -1
                at += 1
        assert at == len(self.index)

    def user_random_sample(self, percent):
        hist, at, total = {}, 0, 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                # NOTE: only sample on the train-set
                if self.index[at] == 0: 
                    if u not in hist: hist[u] = 0
                    hist[u] += 1
                    total += 1
                at += 1

        # Remove `percent`% at random
        user_freqs = list(hist.items()) ; np.random.shuffle(user_freqs)
        interactions_to_remove, removed, users_to_remove = total * (float(percent) / 100.0), 0, set()
        for u, cnt in user_freqs:
            if removed >= interactions_to_remove: break
            users_to_remove.add(u)
            removed += cnt

        at = 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                if u in users_to_remove and self.index[at] == 0: self.index[at] = -1
                at += 1
        assert at == len(self.index)

    def temporal_sample(self, percent):
        hist, at = {}, 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                if u not in hist: hist[u] = []
                # NOTE: only sample on the train-set
                if self.index[at] == 0: hist[u].append(at) 
                at += 1

        # Remove first `percent`% interactions for each user
        remove_mask = {}
        for u in hist:
            interactions = hist[u]
            # random.shuffle(interactions) ### No shuffling, remove first % interactions
            split_point = math.ceil(float(len(interactions)) * (float(percent) / 100.0))
            for i in interactions[:split_point]: remove_mask[i] = True

        at = 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                if remove_mask.get(at, False) and self.index[at] == 0: self.index[at] = -1
                at += 1
        assert at == len(self.index)

    def tail_user_remove(self, percent):
        hist, at, total = {}, 0, 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                # NOTE: only count on the train-set
                if self.index[at] == 0: 
                    if u not in hist: hist[u] = 0
                    hist[u] += 1
                    total += 1
                at += 1

        user_freqs = sorted(list(hist.items()), key = lambda x: x[1])
        interactions_to_remove, removed, users_to_remove = total * (float(percent) / 100.0), 0, set()
        for u, cnt in user_freqs:
            if removed >= interactions_to_remove: break
            users_to_remove.add(u)
            removed += cnt

        at = 0
        for u in range(len(self.data)):
            for i, r, t in self.data[u]:
                if u in users_to_remove and self.index[at] == 0: self.index[at] = -1
                at += 1
        assert at == len(self.index)

    def svp_sample(self, percent, svp_handler, sampling_type):
        self.index = {
            'forgetting_events': svp_handler.forgetting_events,
            'forgetting_events_user': svp_handler.forgetting_events_user,
            'forgetting_events_propensity': svp_handler.forgetting_events_propensity,
            'forgetting_events_user_propensity': svp_handler.forgetting_events_user_propensity,
        }[sampling_type](percent, self.data, self.index)

    def construct_nx_graph(self):
        """Build the user-item graph of the current train interactions.

        Users and items each get their own integer node id, allocated in order
        of first appearance. Returns
        `(graph, rev_user_map, rev_item_map, user_actions, item_actions)`, where
        the `rev_*` dicts map node id -> user / item and the `*_actions` dicts
        map user / item -> flat positions of their train interactions.
        """
        g = nx.Graph()

        user_map, item_map, rev_user_map, rev_item_map = {}, {}, {}, {}
        user_actions, item_actions = defaultdict(list), defaultdict(list)
        at, node_num = 0, 0

        for u, user_data in enumerate(self.data):
            for i, _rating, _time in user_data:
                position = at
                at += 1
                # NOTE: only sample on the train-set
                if self.index[position] != 0: continue

                user_actions[u].append(position)
                item_actions[i].append(position)

                if u not in user_map:
                    user_map[u] = node_num
                    rev_user_map[node_num] = u
                    g.add_node(node_num)
                    node_num += 1
                if i not in item_map:
                    item_map[i] = node_num
                    rev_item_map[node_num] = i
                    g.add_node(node_num)
                    node_num += 1

                g.add_edge(user_map[u], item_map[i])

        assert node_num == g.number_of_nodes()
        return g, rev_user_map, rev_item_map, user_actions, item_actions

    def pagerank_sample(self, percent):
        """Remove all train interactions of the nodes at the end of the PageRank
        ranking, node by node, until at least `percent`% of the graph's edges
        worth of interactions are removed."""
        # networkx graph of the train interactions
        g, rev_user_map, rev_item_map, user_actions, item_actions = self.construct_nx_graph()

        # Convert to networkit
        nk_g = nk.nxadapter.nx2nk(g)

        # Run pagerank
        pr = nk.centrality.PageRank(nk_g, 1e-6)
        pr.run()

        # Remove `percent`% acc to pagerank scores
        # THOUGHT: the nodes with the least pagerank scores will most probably be the tail users/items
        interactions_to_remove = nk_g.numberOfEdges() * (float(percent) / 100.0)
        removed = 0
        for node, _ in pr.ranking()[::-1]:
            if removed >= interactions_to_remove: break

            # A node is either a user or an item; fetch its train interactions
            if node in rev_user_map: positions = user_actions[rev_user_map[node]]
            else: positions = item_actions[rev_item_map[node]]

            for at in positions:
                # Count each interaction once, even if both of its endpoints get removed
                if self.index[at] != -1:
                    removed += 1
                    self.index[at] = -1

    def random_walk_sample(self, percent):
        """Drop ~`percent`% of the train interactions using random-walk-with-restart sampling.

        Until the budget is met, the train graph is rebuilt from the remaining
        interactions and a `RandomWalkWithRestartSampler` picks the nodes to
        *keep*. Train interactions of the remaining nodes are flagged -1 in
        `self.index`. The budget is only checked between nodes, so the last
        node may overshoot it.

        Args:
            percent: Percentage (0-100) of the current train interactions to remove.
        """
        # Number of interactions currently flagged as train (0)
        num_entries = sum(len(history) for history in self.data)
        train_total = sum(1 for at in range(num_entries) if self.index[at] == 0)

        budget, dropped = float(train_total) * (float(percent) / 100.0), 0
        while dropped < budget:
            # Graph over what is still in the train-set
            graph, node_users, node_items, user_actions, item_actions = self.construct_nx_graph()

            # The sampler is sized by the number of nodes to keep
            keep = int(graph.number_of_nodes() * (float(100 - percent) / 100.0))
            sampler = RandomWalkWithRestartSampler(number_of_nodes = keep)
            sampler._create_initial_node_set(graph, None)

            # Walk until enough nodes have been visited
            while len(sampler._sampled_nodes) < sampler.number_of_nodes:
                sampler._do_a_step(graph)

            ## `sampler._sampled_nodes` are the nodes that are kept, not removed
            discarded = list(sampler._set_of_nodes.difference(sampler._sampled_nodes))

            for node in discarded:
                if dropped >= budget: break

                if node in node_users: positions = user_actions[node_users[node]]
                else: positions = item_actions[node_items[node]]

                for at in positions:
                    if self.index[at] != -1: dropped += 1
                    self.index[at] = -1

    def forest_fire_sample(self, percent):
        """Drop ~`percent`% of the train interactions using forest-fire sampling.

        Until the budget is met, the train graph is rebuilt from the remaining
        interactions and a `ForestFireSampler` picks the nodes to *keep*.
        Train interactions of the remaining nodes are flagged -1 in
        `self.index`. The budget is only checked between nodes, so the last
        node may overshoot it.

        Args:
            percent: Percentage (0-100) of the current train interactions to remove.
        """
        # Number of interactions currently flagged as train (0)
        num_entries = sum(len(history) for history in self.data)
        train_total = sum(1 for at in range(num_entries) if self.index[at] == 0)

        budget, dropped = float(train_total) * (float(percent) / 100.0), 0
        while dropped < budget:
            # Graph over what is still in the train-set
            graph, node_users, node_items, user_actions, item_actions = self.construct_nx_graph()

            # The sampler is sized by the number of nodes to keep
            keep = int(graph.number_of_nodes() * (float(100 - percent) / 100.0))
            sampler = ForestFireSampler(number_of_nodes = keep)
            sampler._create_node_sets(graph)

            # Keep starting fires until enough nodes have burned
            while len(sampler._sampled_nodes) < sampler.number_of_nodes:
                sampler._start_a_fire(graph)

            ## `sampler._sampled_nodes` are the nodes that are kept, not removed
            discarded = list(sampler._set_of_nodes.difference(sampler._sampled_nodes))

            for node in discarded:
                if dropped >= budget: break

                if node in node_users: positions = user_actions[node_users[node]]
                else: positions = item_actions[node_items[node]]

                for at in positions:
                    if self.index[at] != -1: dropped += 1
                    self.index[at] = -1

    def measure_data_stats(self):
        """Summarise the data under the current index map.

        Index flags are read as 0 = train, 1 = validation, 2 = test and
        -1 = removed. Users and items are counted when they have at least one
        interaction that is not flagged -1.

        Returns:
            dict with keys "num_users", "num_items", "num_train_interactions",
            "num_test" and "num_val" (in that order).
        """
        seen_users, seen_items = set(), set()
        train_count = val_count = test_count = 0

        position = -1 # Flat position of the current interaction inside `self.index`
        for user, history in enumerate(self.data):
            for item, _, _ in history:
                position += 1
                flag = self.index[position]

                # Removed interactions contribute to nothing
                if flag == -1: continue

                seen_users.add(user)
                seen_items.add(item)

                if flag == 0: train_count += 1
                elif flag == 1: val_count += 1
                elif flag == 2: test_count += 1

        return {
            "num_users": len(seen_users),
            "num_items": len(seen_items),
            "num_train_interactions": train_count,
            "num_test": test_count,
            "num_val": val_count,
        }

    def save_index(self, path, statistics = True):
        """Write the current index map (and optionally its statistics) under `path`.

        Creates `path` if needed and stores `self.index` as `index.npz` (key
        "data"). With `statistics` enabled, the output of
        `measure_data_stats()` is printed and dumped to `data_stats.json`.
        When `self.complete_data_stats` is set, the printout reports the
        percentage removed relative to it instead of the raw counts.

        Args:
            path: Output directory.
            statistics: Whether to print and store the data statistics.
        """
        os.makedirs(path, exist_ok = True)
        with open(os.path.join(path, "index.npz"), "wb") as f:
            np.savez_compressed(f, data = self.index)

        if not statistics: return

        data_stats = self.measure_data_stats()
        if self.complete_data_stats is None:
            print("FULL DATA:", data_stats)
        else:
            baseline = self.complete_data_stats

            def convert(key):
                # Percentage of `key` removed w.r.t. the complete data. A zero
                # baseline has nothing to remove; reporting 0.0 avoids a
                # ZeroDivisionError that would abort before the JSON is written
                if not baseline[key]: return 0.0
                return round(100.0 - (100.0 * (float(data_stats[key]) / float(baseline[key]))), 2)

            print("SAMPLE SIZE: {}% users ; {}% items ; {}% train interactions ; {}% test interactions removed".format(
                convert('num_users'), convert('num_items'), convert('num_train_interactions'), convert('num_test')
            ))

        with open(os.path.join(path, "data_stats.json"), 'w') as f:
            json.dump(data_stats, f)

    def load_index(self, path):
        """Load the index map stored at `path + "/index.npz"` into `self.index`.

        If no reference statistics are set yet (`self.complete_data_stats` is
        None), they are computed from the freshly loaded index.

        Args:
            path: Directory containing an `index.npz` archive with a "data" array.
        """
        # np.load returns an NpzFile that holds the file open until it is closed;
        # the array is fully read on access, so it stays valid after the `with`
        with np.load(path + "/index.npz") as archive:
            self.index = archive['data']

        if self.complete_data_stats is None: self.complete_data_stats = self.measure_data_stats()

    def save_data(self, path):
        """Dump every interaction to `path + '/total_data.hdf5'`.

        The per-user histories are flattened into rows of
        `[user, item, rating, time]` and written as four gzip-compressed
        datasets named "user", "item", "rating" and "time".

        Args:
            path: Output directory; created if it does not exist.
        """
        # One row per interaction, prefixed with the position of its user in `self.data`
        rows = [ [ user ] + entry for user, history in enumerate(self.data) for entry in history ]
        table = np.array(rows)

        shape = [ len(table) ]
        columns = [ ('user', 'i4'), ('item', 'i4'), ('rating', 'f'), ('time', 'i4') ]

        os.makedirs(path, exist_ok = True)
        with h5py.File(path + '/total_data.hdf5', 'w') as file:
            # Create all datasets first, then fill them column by column
            dset = {
                name: file.create_dataset(name, shape, dtype = dtype, maxshape = shape, compression="gzip")
                for name, dtype in columns
            }
            for column, (name, _) in enumerate(columns):
                dset[name][:] = table[:, column]

## Download Data

In [None]:
# Create the working directories up front. `mkdir -p` also creates missing
# parent directories and does not fail when a directory already exists.

# Where to store the datasets?
!mkdir -p datasets/

# Where to store the logs/models of trained models
!mkdir -p experiments/sampling_runs/results/logs/trained/
!mkdir -p experiments/sampling_runs/results/models/trained/

# Where to store the logs/models of trained proxy models for SVP
!mkdir -p experiments/sampling_runs/results/logs/SVP/
!mkdir -p experiments/sampling_runs/results/models/SVP/

# Base directory for all Data-Genie experiments
!mkdir -p experiments/data_genie/

In [None]:
# (disabled) Alternative dataset: Amazon "Magazine Subscriptions" ratings
# # download_amazon_magazine
# !mkdir datasets/magazine/
# link="http://deepyeti.ucsd.edu/jianmo/amazon/categoryFilesSmall/Magazine_Subscriptions.csv"
# !wget $link -P datasets/magazine/ -q --show-progress
# !mv datasets/magazine/Magazine_Subscriptions.csv datasets/magazine/data.csv

# download_ml_100k
# Fetch the archive into datasets/ (quiet apart from the progress bar),
# unpack it there without chatter, then delete the zip.
link="https://files.grouplens.org/datasets/movielens/ml-100k.zip"
!wget $link -P datasets/ -q --show-progress
!unzip -qq datasets/ml-100k.zip -d datasets/ 
!rm datasets/ml-100k.zip



In [None]:
# Ratings file read by prep(): one tab-separated `user, item, rating, timestamp` record per line
data_path = './datasets/ml-100k/u.data'

In [None]:
def prep(dataset):
	"""Load the ratings file at `data_path` into per-user, time-sorted histories.

	Every non-blank line must hold four tab-separated fields: user id, item id,
	rating and timestamp. User ids are expected to be contiguous, starting at
	1 or at 0.

	Args:
		dataset: Unused; the file location comes from the module-level `data_path`.

	Returns:
		The result of `rating_data(remap_items(data))`, where `data[k]` is the
		list of `[item, rating, time]` entries of the k-th user (ids shifted by
		the smallest user id), sorted by time.

	Raises:
		ValueError: If a line does not have exactly four fields, a field cannot
			be parsed, or the file contains no ratings.
		AssertionError: If the user ids are not contiguous.
	"""
	users, items, ratings, timestamps = [], [], [], []

	# The context manager closes the file even if a line fails to parse
	with open(data_path) as f:
		for line in f:
			line = line.strip()
			if not line: continue # Tolerate blank lines, e.g. a trailing newline at EOF

			u, i, r, t = line.split('\t')
			users.append(int(u))
			items.append(int(i))
			ratings.append(float(r))
			timestamps.append(int(t))

	min_user = min(users) ; max_user = max(users)
	num_users = len(set(users))

	# Ids must cover [min_user, max_user] without gaps (1-based or 0-based)
	if min_user == 1:
		assert num_users == max_user
	else:
		assert num_users == max_user + 1

	# Group interactions by user; the smallest id maps to position 0
	data = [ [] for _ in range(num_users) ]
	for user, item, rating, timestamp in zip(users, items, ratings, timestamps):
		data[user - min_user].append([ item, rating, timestamp ])

	# Time sort data
	for history in data:
		history.sort(key = lambda x: x[2])

	# Shuffling users
	# indices = np.arange(len(data)) ; np.random.shuffle(indices)
	# data = np.array(data)[indices].tolist()

	return rating_data(remap_items(data))

## Preprocess

In [None]:
# from initial_data_prep_code import movielens, amazon, goodreads, beeradvocate
# from data_path_constants import get_data_path
# from svp_handler import SVPHandler

percent_sample = [ 20, 40, 60, 80, 90, 99 ]

# One row per heuristic sampler, run in this order:
# (banner template, sampling method, extra positional args, output-folder suffix)
heuristic_samplers = [
	# Frequency sample from user hist (Stratified)
	("\n{} split, user history random sampling", 'frequency_sample', (0,), "_perc_freq_user_rns"),
	# Sample users randomly
	("\n{} split, user random sampling", 'user_random_sample', (), "_perc_user_rns"),
	# Sample interactions randomly
	("\n{} split, interaction random sampling", 'interaction_random_sample', (), "_perc_interaction_rns"),
	# Temporal sampling
	("\n{} split, user history temporal sampling", 'temporal_sample', (), "_perc_temporal"),
	# Remove tail users sampling
	("\n{} split, tail user sampling", 'tail_user_remove', (), "_perc_tail_user_remove"),
	# Pagerank based sampling
	("\n{} split, pagerank sampling", 'pagerank_sample', (), "_perc_pagerank"),
	# RW based sampling
	("\n{} split, random walk sampling", 'random_walk_sample', (), "_perc_random_walk"),
	# Forest-fire based sampling
	("\n{} split, forest fire sampling", 'forest_fire_sample', (), "_perc_forest_fire"),
]

# SVP sampling strategies, tried for every (proxy model, loss) pair
svp_strategies = [
	'forgetting_events',
	'forgetting_events_propensity',
	'forgetting_events_user',
	'forgetting_events_user_propensity',
]

# Which datasets to prep?
for dataset in [
	'ml-100k',
	# 'magazine',
	# 'luxury',
	# 'video_games',
	# 'beeradvocate',
	# 'goodreads_comics',
]:

	print("\n\n\n!!!!!!!! STARTED PROCESSING {} !!!!!!!!\n\n\n".format(dataset))

	# NOTE(review): the imports of `amazon`, `goodreads` and `beeradvocate` are commented
	# out above, so only the 'ml-100k' branch is usable unless they are defined elsewhere
	# if dataset in [ 'ml-100k' ]: total_data = movielens.prep(dataset)
	if dataset in [ 'ml-100k' ]: total_data = prep(dataset)
	elif dataset in [ 'luxury', 'magazine', 'video_games' ]: total_data = amazon.prep(dataset)
	elif dataset in [ 'goodreads_comics' ]: total_data = goodreads.prep(dataset)
	elif dataset in [ 'beeradvocate' ]: total_data = beeradvocate.prep(dataset)

	# Store original data
	total_data.save_data(get_data_path(dataset))

	# Sampling
	# for train_test_split in [ '20_percent_hist', 'leave_2' ]:
	for train_test_split in ['leave_2']:

		total_data.complete_data_stats = None # Since task changed
		path_uptil_now = get_data_path(dataset) + "/" + train_test_split + "/"
		complete_index_path = path_uptil_now + "/complete_data/"

		# Make full-data (No sampling)
		total_data.train_test_split(train_test_split)
		print("\n{} split, Overall:".format(train_test_split))
		total_data.save_index(complete_index_path)

		# Heuristic samplers: always restart from the complete index map
		for banner, method_name, extra_args, suffix in heuristic_samplers:
			print(banner.format(train_test_split))
			for percent in percent_sample:
				total_data.load_index(complete_index_path) # Re-load index map
				getattr(total_data, method_name)(percent, *extra_args)
				total_data.save_index(path_uptil_now + str(percent) + suffix)

		# Sample interactions according to SVP
		hyper_params = {
			'dataset': dataset,
			'sampling': 'complete_data', # While training the proxy model
		}

		for proxy_model in [ 'bias_only', 'MF_dot' ]:
			if train_test_split == 'leave_2': scenarios = [ 'sequential' ]
			else: scenarios = [ 'implicit', 'explicit' ]

			for loss_type in scenarios:
				print()
				svp_handler = SVPHandler(proxy_model, loss_type, hyper_params)

				for sampling in svp_strategies:
					print("\n{} split, SVP: {}_{}, {} loss".format(train_test_split, proxy_model, sampling, loss_type))
					for percent in percent_sample:
						total_data.load_index(complete_index_path) # Re-load index map
						total_data.svp_sample(percent, svp_handler, sampling)
						total_data.save_index(path_uptil_now + "svp_{}_{}/{}_perc_{}".format(proxy_model, loss_type, percent, sampling))




!!!!!!!! STARTED PROCESSING ml-100k !!!!!!!!




leave_2 split, Overall:
FULL DATA: {'num_users': 943, 'num_items': 1682, 'num_train_interactions': 98114, 'num_test': 943, 'num_val': 943}

leave_2 split, user history random sampling
SAMPLE SIZE: 0.0% users ; 1.9% items ; 20.38% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 3.69% items ; 40.39% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 8.32% items ; 60.38% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 15.76% items ; 80.39% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 22.89% items ; 90.43% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 47.15% items ; 99.41% train interactions ; 0.0% test interactions removed

leave_2 split, user random sampling
SAMPLE SIZE: 0.0% users ; 0.71% items ; 20.01% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 3.98% 

100%|██████████| 96/96 [00:00<00:00, 370.27it/s]


Saving model...
-----------------------------------------------------------------------------------------
| end of epoch 1 | time =  0.46 | eval = partial | HR@10 = 43.9024 | HR@100 = 100.0 | NDCG@10 = 23.3166 | NDCG@100 = 35.1766 | PSP@10 = 43.9024 | PSP@100 = 100.0 | AUC = 0.807 | dataset = ml-100k (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 371.47it/s]


-----------------------------------------------------------------------------------------
| end of epoch 2 | time =  0.28 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 120.76it/s]


-----------------------------------------------------------------------------------------
| end of epoch 3 | time =  0.82 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 213.31it/s]


-----------------------------------------------------------------------------------------
| end of epoch 4 | time =  0.48 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 216.19it/s]


-----------------------------------------------------------------------------------------
| end of epoch 5 | time =  0.46 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 419.80it/s]


-----------------------------------------------------------------------------------------
| end of epoch 6 | time =  0.25 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 194.11it/s]


-----------------------------------------------------------------------------------------
| end of epoch 7 | time =  0.51 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 211.21it/s]


-----------------------------------------------------------------------------------------
| end of epoch 8 | time =  0.47 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 342.73it/s]


-----------------------------------------------------------------------------------------
| end of epoch 9 | time =  0.29 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 205.96it/s]


-----------------------------------------------------------------------------------------
| end of epoch 10 | time =  0.48 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 563.66it/s]


-----------------------------------------------------------------------------------------
| end of epoch 11 | time =  0.19 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 178.94it/s]


-----------------------------------------------------------------------------------------
| end of epoch 12 | time =  0.55 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 196.53it/s]


-----------------------------------------------------------------------------------------
| end of epoch 13 | time =  0.51 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 396.66it/s]


-----------------------------------------------------------------------------------------
| end of epoch 14 | time =  0.27 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 196.47it/s]


-----------------------------------------------------------------------------------------
| end of epoch 15 | time =  0.52 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 175.90it/s]


-----------------------------------------------------------------------------------------
| end of epoch 16 | time =  0.58 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 563.67it/s]


-----------------------------------------------------------------------------------------
| end of epoch 17 | time =  0.20 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 194.86it/s]


-----------------------------------------------------------------------------------------
| end of epoch 18 | time =  0.51 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 194.51it/s]


-----------------------------------------------------------------------------------------
| end of epoch 19 | time =  0.51 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 431.75it/s]


-----------------------------------------------------------------------------------------
| end of epoch 20 | time =  0.24 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 192.40it/s]


-----------------------------------------------------------------------------------------
| end of epoch 21 | time =  0.51 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 187.05it/s]


-----------------------------------------------------------------------------------------
| end of epoch 22 | time =  0.53 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 457.61it/s]


-----------------------------------------------------------------------------------------
| end of epoch 23 | time =  0.24 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 212.24it/s]


-----------------------------------------------------------------------------------------
| end of epoch 24 | time =  0.47 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 223.01it/s]


-----------------------------------------------------------------------------------------
| end of epoch 25 | time =  0.45 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 487.00it/s]


-----------------------------------------------------------------------------------------
| end of epoch 26 | time =  0.22 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 207.09it/s]


-----------------------------------------------------------------------------------------
| end of epoch 27 | time =  0.49 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 226.52it/s]


-----------------------------------------------------------------------------------------
| end of epoch 28 | time =  0.45 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 511.93it/s]


-----------------------------------------------------------------------------------------
| end of epoch 29 | time =  0.22 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 210.63it/s]


-----------------------------------------------------------------------------------------
| end of epoch 30 | time =  0.48 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 211.43it/s]


-----------------------------------------------------------------------------------------
| end of epoch 31 | time =  0.48 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 383.23it/s]


-----------------------------------------------------------------------------------------
| end of epoch 32 | time =  0.28 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 237.53it/s]


-----------------------------------------------------------------------------------------
| end of epoch 33 | time =  0.43 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 520.94it/s]


-----------------------------------------------------------------------------------------
| end of epoch 34 | time =  0.21 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 216.69it/s]


-----------------------------------------------------------------------------------------
| end of epoch 35 | time =  0.47 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 224.29it/s]


-----------------------------------------------------------------------------------------
| end of epoch 36 | time =  0.45 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 391.15it/s]


-----------------------------------------------------------------------------------------
| end of epoch 37 | time =  0.26 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 215.83it/s]


-----------------------------------------------------------------------------------------
| end of epoch 38 | time =  0.46 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 205.35it/s]


-----------------------------------------------------------------------------------------
| end of epoch 39 | time =  0.48 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 403.35it/s]


-----------------------------------------------------------------------------------------
| end of epoch 40 | time =  0.26 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 227.18it/s]


-----------------------------------------------------------------------------------------
| end of epoch 41 | time =  0.44 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 214.41it/s]


-----------------------------------------------------------------------------------------
| end of epoch 42 | time =  0.46 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 514.64it/s]


-----------------------------------------------------------------------------------------
| end of epoch 43 | time =  0.21 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 203.84it/s]


-----------------------------------------------------------------------------------------
| end of epoch 44 | time =  0.49 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 197.37it/s]


-----------------------------------------------------------------------------------------
| end of epoch 45 | time =  0.51 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 749.53it/s]


-----------------------------------------------------------------------------------------
| end of epoch 46 | time =  0.14 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 180.49it/s]


-----------------------------------------------------------------------------------------
| end of epoch 47 | time =  0.54 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 214.98it/s]


-----------------------------------------------------------------------------------------
| end of epoch 48 | time =  0.47 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 547.20it/s]


-----------------------------------------------------------------------------------------
| end of epoch 49 | time =  0.20 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 210.92it/s]


-----------------------------------------------------------------------------------------
| end of epoch 50 | time =  0.49 (VAL)
-----------------------------------------------------------------------------------------

leave_2 split, SVP: bias_only_forgetting_events, sequential loss
SAMPLE SIZE: 0.0% users ; 0.0% items ; 20.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 0.0% items ; 40.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 0.12% items ; 60.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 1.37% items ; 80.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 4.64% items ; 90.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 27.35% items ; 99.0% train interactions ; 0.0% test interactions removed

leave_2 split, SVP: bias_only_forgetting_events_propensity, sequential loss
SAMPLE SIZE: 0.0% users ; 0.0% items ; 20.0% train 

100%|██████████| 96/96 [00:00<00:00, 225.61it/s]


Saving model...
-----------------------------------------------------------------------------------------
| end of epoch 1 | time =  0.62 | eval = partial | HR@10 = 44.3266 | HR@100 = 100.0 | NDCG@10 = 23.8439 | NDCG@100 = 35.5882 | PSP@10 = 44.3266 | PSP@100 = 100.0 | AUC = 0.8074 | dataset = ml-100k (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 127.30it/s]


-----------------------------------------------------------------------------------------
| end of epoch 2 | time =  0.77 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 194.73it/s]


-----------------------------------------------------------------------------------------
| end of epoch 3 | time =  0.52 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 244.25it/s]


-----------------------------------------------------------------------------------------
| end of epoch 4 | time =  0.41 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 169.38it/s]


-----------------------------------------------------------------------------------------
| end of epoch 5 | time =  0.58 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 259.09it/s]


-----------------------------------------------------------------------------------------
| end of epoch 6 | time =  0.39 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 206.67it/s]


-----------------------------------------------------------------------------------------
| end of epoch 7 | time =  0.49 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 145.65it/s]


-----------------------------------------------------------------------------------------
| end of epoch 8 | time =  0.68 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 282.39it/s]


-----------------------------------------------------------------------------------------
| end of epoch 9 | time =  0.35 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 203.67it/s]


-----------------------------------------------------------------------------------------
| end of epoch 10 | time =  0.49 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 145.79it/s]


-----------------------------------------------------------------------------------------
| end of epoch 11 | time =  0.67 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 154.72it/s]


-----------------------------------------------------------------------------------------
| end of epoch 12 | time =  0.63 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 248.68it/s]


-----------------------------------------------------------------------------------------
| end of epoch 13 | time =  0.40 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 259.79it/s]


-----------------------------------------------------------------------------------------
| end of epoch 14 | time =  0.40 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 275.91it/s]


-----------------------------------------------------------------------------------------
| end of epoch 15 | time =  0.37 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 209.08it/s]


-----------------------------------------------------------------------------------------
| end of epoch 16 | time =  0.48 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 286.43it/s]


-----------------------------------------------------------------------------------------
| end of epoch 17 | time =  0.35 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 210.56it/s]


-----------------------------------------------------------------------------------------
| end of epoch 18 | time =  0.47 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 283.16it/s]


-----------------------------------------------------------------------------------------
| end of epoch 19 | time =  0.35 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 211.28it/s]


-----------------------------------------------------------------------------------------
| end of epoch 20 | time =  0.47 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 247.13it/s]


-----------------------------------------------------------------------------------------
| end of epoch 21 | time =  0.41 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 205.28it/s]


-----------------------------------------------------------------------------------------
| end of epoch 22 | time =  0.49 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 159.45it/s]


-----------------------------------------------------------------------------------------
| end of epoch 23 | time =  0.61 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 166.85it/s]


-----------------------------------------------------------------------------------------
| end of epoch 24 | time =  0.59 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 261.03it/s]


-----------------------------------------------------------------------------------------
| end of epoch 25 | time =  0.39 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 248.74it/s]


-----------------------------------------------------------------------------------------
| end of epoch 26 | time =  0.40 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 292.32it/s]


-----------------------------------------------------------------------------------------
| end of epoch 27 | time =  0.35 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 188.26it/s]


-----------------------------------------------------------------------------------------
| end of epoch 28 | time =  0.52 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 169.26it/s]


-----------------------------------------------------------------------------------------
| end of epoch 29 | time =  0.59 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 159.14it/s]


-----------------------------------------------------------------------------------------
| end of epoch 30 | time =  0.62 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 291.42it/s]


-----------------------------------------------------------------------------------------
| end of epoch 31 | time =  0.34 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 181.98it/s]


-----------------------------------------------------------------------------------------
| end of epoch 32 | time =  0.54 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 140.66it/s]


-----------------------------------------------------------------------------------------
| end of epoch 33 | time =  0.71 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 159.85it/s]


-----------------------------------------------------------------------------------------
| end of epoch 34 | time =  0.61 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 256.61it/s]


-----------------------------------------------------------------------------------------
| end of epoch 35 | time =  0.39 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 246.01it/s]


-----------------------------------------------------------------------------------------
| end of epoch 36 | time =  0.41 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 278.98it/s]


-----------------------------------------------------------------------------------------
| end of epoch 37 | time =  0.36 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 189.22it/s]


-----------------------------------------------------------------------------------------
| end of epoch 38 | time =  0.53 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 290.66it/s]


-----------------------------------------------------------------------------------------
| end of epoch 39 | time =  0.35 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 214.90it/s]


-----------------------------------------------------------------------------------------
| end of epoch 40 | time =  0.46 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 165.54it/s]


-----------------------------------------------------------------------------------------
| end of epoch 41 | time =  0.61 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 156.31it/s]


-----------------------------------------------------------------------------------------
| end of epoch 42 | time =  0.63 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 145.61it/s]


-----------------------------------------------------------------------------------------
| end of epoch 43 | time =  0.68 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 266.31it/s]


-----------------------------------------------------------------------------------------
| end of epoch 44 | time =  0.38 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 215.99it/s]


-----------------------------------------------------------------------------------------
| end of epoch 45 | time =  0.45 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 152.49it/s]


-----------------------------------------------------------------------------------------
| end of epoch 46 | time =  0.65 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 148.62it/s]


-----------------------------------------------------------------------------------------
| end of epoch 47 | time =  0.66 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 153.52it/s]


-----------------------------------------------------------------------------------------
| end of epoch 48 | time =  0.64 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 136.31it/s]


-----------------------------------------------------------------------------------------
| end of epoch 49 | time =  0.73 (VAL)
-----------------------------------------------------------------------------------------


100%|██████████| 96/96 [00:00<00:00, 163.67it/s]


-----------------------------------------------------------------------------------------
| end of epoch 50 | time =  0.61 (VAL)
-----------------------------------------------------------------------------------------

leave_2 split, SVP: MF_dot_forgetting_events, sequential loss
SAMPLE SIZE: 0.0% users ; 0.0% items ; 20.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 0.0% items ; 40.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 0.0% items ; 60.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 0.12% items ; 80.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 1.01% items ; 90.0% train interactions ; 0.0% test interactions removed
SAMPLE SIZE: 0.0% users ; 26.4% items ; 99.0% train interactions ; 0.0% test interactions removed

leave_2 split, SVP: MF_dot_forgetting_events_propensity, sequential loss
SAMPLE SIZE: 0.0% users ; 0.0% items ; 20.0% train interact

In [None]:
!tree --du -h -C ./datasets/ml-100k

[01;34m./datasets/ml-100k[00m
├── [5.6M]  [01;34m20_percent_hist[00m
│   ├── [ 42K]  [01;34m20_perc_forest_fire[00m
│   │   ├── [ 106]  data_stats.json
│   │   └── [ 38K]  index.npz
│   ├── [ 47K]  [01;34m20_perc_freq_user_rns[00m
│   │   ├── [ 106]  data_stats.json
│   │   └── [ 43K]  index.npz
│   ├── [ 47K]  [01;34m20_perc_interaction_rns[00m
│   │   ├── [ 106]  data_stats.json
│   │   └── [ 42K]  index.npz
│   ├── [ 44K]  [01;34m20_perc_pagerank[00m
│   │   ├── [ 106]  data_stats.json
│   │   └── [ 39K]  index.npz
│   ├── [ 44K]  [01;34m20_perc_random_walk[00m
│   │   ├── [ 106]  data_stats.json
│   │   └── [ 40K]  index.npz
│   ├── [ 33K]  [01;34m20_perc_tail_user_remove[00m
│   │   ├── [ 106]  data_stats.json
│   │   └── [ 29K]  index.npz
│   ├── [ 35K]  [01;34m20_perc_temporal[00m
│   │   ├── [ 106]  data_stats.json
│   │   └── [ 31K]  index.npz
│   ├── [ 33K]  [01;34m20_perc_user_rns[00m
│   │   ├── [ 106]  data_stats.json
│   │   └── [ 29K]  index.npz
│   ├

## References
- [https://arxiv.org/abs/2201.04768v1](https://arxiv.org/abs/2201.04768v1)
- [https://github.com/noveens/sampling_cf](https://github.com/noveens/sampling_cf)

## Theory
| Sampling Strategy | What is sampled? |
| --- | --- |
| Random | Interactions |
| Stratified | Interactions |
| Temporal | Interactions |
| SVP-CF w/ MF | Interactions |
| SVP-CF w/ Bias-only | Interactions |
| SVP-CF-Prop w/ MF | Interactions |
| SVP-CF-Prop w/ Bias-only | Interactions |
| Random | Users |
| Head | Users |
| SVP-CF w/ MF | Users |
| SVP-CF w/ Bias-only | Users |
| SVP-CF-Prop w/ MF | Users |
| SVP-CF-Prop w/ Bias-only | Users |
| Centrality | Graph |
| Random-Walk | Graph |
| Forest-Fire | Graph |

### Interaction sampling

- In Random Interaction Sampling, we generate $D^{s,p}$ (the $p\%$ sample of the full dataset $D$ produced by sampling strategy $s$) by randomly sampling $p\%$ of all the user-item interactions in $D$.
- User-history Stratified Sampling is another popular technique for generating smaller CF datasets. To match the user-frequency distribution between $D$ and $D^{s,p}$, it randomly samples $p\%$ of the interactions from each user’s consumption history.
- Unlike random stratified sampling, User-history Temporal Sampling retains the most recent $p\%$ of each user's interactions. This strategy is representative of the popular practice of building data subsets from the online traffic of the last $x$ days.

### User sampling

*To ensure a fair comparison among the different kinds of sampling schemes, every sample $D^{s,p}$ retains exactly $p\%$ of the total interactions in $D$.*

- In Random User Sampling, we retain users from $D$ at random. More specifically, we iteratively pick a random user and preserve all of their interactions until we have retained $p\%$ of the original interactions.
- Another strategy we employ is Head User Sampling, in which we iteratively remove the user with the fewest total interactions until only $p\%$ of the original interactions remain. This method is representative of commonly used data pre-processing strategies that make data suitable for parameter-heavy algorithms. Because it discards the least active users first, sampling the data in this way can introduce bias against users from minority groups, which might raise concerns from a diversity and fairness perspective.

### Graph sampling

- In Centrality-based Sampling, we proceed by computing the PageRank centrality score of each node in the interaction graph $G$ (whose edges are the interactions in $D$), and retain all the edges (interactions) of the top-scoring nodes until a total of $p\%$ of the original interactions have been preserved.
- Another popular strategy we employ is Random-walk Sampling, which performs multiple random walks with restart on $G$ and retains the edges between those pairs of nodes that have been visited at least once. We keep expanding our walk until $p\%$ of the initial edges have been retained.
- We also utilize Forest-fire Sampling, which is a snowball sampling method that proceeds by randomly “burning” the outgoing edges of visited nodes. It starts with a random node, and then propagates to a random subset of previously unvisited neighbors. The propagation is terminated once we have created a subgraph with $p\%$ of the initial edges.

### SVP-CF: Selection-Via-Proxy for CF data

Irrespective of whether users or interactions are being sampled, SVP-CF proceeds by training an inexpensive proxy model $P$ on the full, original data $D$, and modifies the forgetting-events approach to retain the points with the highest importance.