In [1]:
import torch
import requests
from tqdm import tqdm
from torch import nn
from torch.nn import Embedding, Linear, Bilinear, BatchNorm1d, ReLU, Dropout, MarginRankingLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset

In [2]:
class PosNegData:

    def __init__(self, pos_data, neg_data, weight):
        self.pos = pos_data
        self.neg = neg_data
        self.weight = weight


class Data:

    def __init__(self, user_id, item_id, metadata):
        self.user_id = torch.tensor(user_id)
        self.item_id = torch.tensor(item_id)
        self.metadata = torch.tensor(metadata)


class DataGenerator(Dataset):

    def __init__(self, state_history, reward_history, action_history):
        self.state_history = state_history
        self.reward_history = reward_history
        self.action_history = action_history
        self.data = []
        self._init_pos_neg()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def _init_pos_neg(self):
        for i, r in enumerate(self.reward_history):
            if r > 0:
                user_id = self.state_history[i][0][0]
                action = self.action_history[i]
                pos_data = Data(user_id=user_id, item_id=self.state_history[i][action][1],
                                metadata=self.state_history[i][action][2:])
                for j, state in enumerate(self.state_history[i]):
                    item_id = state[1]
                    metadata = state[2:]
                    data = Data(user_id=user_id, item_id=item_id, metadata=metadata)
                    if j != action:
                        self.data.append(PosNegData(pos_data, data, 1))

    def add_data(self, state, action, reward):
        if reward > 0:
            user_id = state[0][0]
            pos_data = Data(user_id=user_id, item_id=state[action][1], metadata=state[action][2:])
            for j, my_state in enumerate(state):
                item_id = my_state[1]
                metadata = my_state[2:]
                data = Data(user_id=user_id, item_id=item_id, metadata=metadata)
                if j != action:
                    self.data.append(PosNegData(pos_data, data, 1))


def collate_data_pos_neg(list_of_data):
    raw_data = [data for data in list_of_data]
    user_id_pos = torch.stack([data.pos.user_id for data in list_of_data])
    item_id_pos = torch.stack([data.pos.item_id for data in list_of_data])
    metadata_pos = torch.stack([data.pos.metadata for data in list_of_data])
    user_id_neg = torch.stack([data.neg.user_id for data in list_of_data])
    item_id_neg = torch.stack([data.neg.item_id for data in list_of_data])
    metadata_neg = torch.stack([data.neg.metadata for data in list_of_data])
    return {'user_id_pos': user_id_pos, 'item_id_pos': item_id_pos, 'metadata_pos': metadata_pos, 'raw_data': raw_data,
            'user_id_neg': user_id_neg, 'item_id_neg': item_id_neg, 'metadata_neg': metadata_neg}


def collate_data(list_of_data):
    user_id = torch.stack([data.user_id for data in list_of_data])
    item_id = torch.stack([data.item_id for data in list_of_data])
    metadata = torch.stack([data.metadata for data in list_of_data])
    return {'user_id': user_id, 'item_id': item_id, 'metadata': metadata}

In [3]:
class Interface:

    def __init__(self, args):
        self.base_url = 'http://{}'.format(args.ip_address_env_2)
        self.user_id = args.user_id
        self.url_reset = '{}/reset'.format(self.base_url)
        self.url_predict = '{}/predict'.format(self.base_url)

        r = requests.get(url=self.url_reset, params={'user_id': self.user_id})
        data = r.json()
        self.state_history = data['state_history']
        self.rewards_history = data['rewards_history']
        self.action_history = data['action_history']

        self.nb_items = data['nb_items']
        self.nb_users = data['nb_users']
        self.nb_variables = len(self.state_history[0][0]) - 2

        self.next_state = data['next_state']

    def reset(self):
        r = requests.get(url=self.url_reset, params={'user_id': self.user_id})
        data = r.json()

        self.state_history = data['state_history']
        self.rewards_history = data['rewards_history']
        self.action_history = data['action_history']

        self.nb_items = data['nb_items']
        self.nb_users = data['nb_users']

        self.next_state = data['next_state']

    def predict(self, recommended_item):
        r = requests.get(url=self.url_predict, params={'user_id': self.user_id, 'recommended_item': recommended_item})
        data = r.json()

        self.state_history.append(data['state'])
        self.rewards_history.append(data['reward'])
        self.action_history.append(recommended_item)

        self.next_state = data['state']
        return data['state'], data['reward']

In [4]:
class SiameseNetwork(nn.Module):

    def __init__(self, interface):
        super(SiameseNetwork, self).__init__()

        user_embedding_dim = 10
        item_embedding_dim = 10
        user_meta_dim = 15
        item_meta_dim = 15
        meta_meta_dim = 30
        dense_1_dim = 32
        dense_2_dim = 15
        out_dim = 1

        self.embedding_user = Embedding(num_embeddings=interface.nb_users, embedding_dim=user_embedding_dim)
        self.embedding_item = Embedding(num_embeddings=interface.nb_items, embedding_dim=item_embedding_dim)
        self.concat_user_meta = Bilinear(in1_features=user_embedding_dim, in2_features=interface.nb_variables, out_features=user_meta_dim)
        self.concat_item_meta = Bilinear(in1_features=item_embedding_dim, in2_features=interface.nb_variables, out_features=item_meta_dim)
        self.concat_meta_meta = Bilinear(in1_features=user_meta_dim, in2_features=item_meta_dim, out_features=meta_meta_dim)
        self.batch_norm_0 = BatchNorm1d(num_features=meta_meta_dim)
        self.dropout_0 = Dropout(0.5)
        self.dense_1 = Linear(in_features=meta_meta_dim, out_features=dense_1_dim)
        self.relu_1 = ReLU()
        self.dropout_1 = Dropout(0.5)
        self.batch_norm_1 = BatchNorm1d(num_features=dense_1_dim)
        self.dense_2 = Linear(in_features=dense_1_dim, out_features=dense_2_dim)
        self.relu_2 = ReLU()
        self.dropout_2 = Dropout(0.5)
        self.batch_norm_2 = BatchNorm1d(num_features=dense_2_dim)
        self.dense_3 = Linear(in_features=dense_2_dim, out_features=out_dim)

    def forward(self, user_id, item_id, metadata):
        user_embedded = self.embedding_user(user_id).squeeze(dim=1)
        item_embedded = self.embedding_item(item_id).squeeze(dim=1)
        user_and_meta = self.concat_user_meta(user_embedded, metadata)
        item_and_meta = self.concat_item_meta(item_embedded, metadata)
        meta_and_meta = self.concat_meta_meta(user_and_meta, item_and_meta)
        output = self.batch_norm_0(meta_and_meta)
        output = self.dropout_0(output)
        output = self.dense_1(output)
        output = self.relu_1(output)
        output = self.batch_norm_1(output)
        output = self.dropout_1(output)
        output = self.dense_2(output)
        output = self.relu_2(output)
        output = self.batch_norm_2(output)
        output = self.dropout_2(output)
        output = self.dense_3(output)
        return output

In [12]:
class Trainer:

    def __init__(self, interface, learning_rate=3e-4, validation_split=0.2, batch_size=32, margin=10, min_weight=1,
                 num_samples=100):
        self.interface = interface
        self.network = SiameseNetwork(interface)
        self.dataset = DataGenerator(interface.state_history, interface.rewards_history, interface.action_history)
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.min_weight = min_weight
        self.num_samples = num_samples

        self.loss = MarginRankingLoss(margin=margin, reduction='none')

        self.optimizer = Adam(self.network.parameters(), lr=learning_rate)
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer, factor=0.3, patience=5, threshold=1e-3, verbose=True)

    def reset(self):
        self.train()
        for k in range(100):
            self.online()

    def train(self):
        weights = [data.weight for data in self.dataset]
        sampler = WeightedRandomSampler(weights=weights, num_samples=self.num_samples, replacement=True)
        data_loader = DataLoader(self.dataset, batch_size=self.batch_size, sampler=sampler,
                                 collate_fn=collate_data_pos_neg, drop_last=True)
        self.network.train()
        cumloss = 0
        for inputs in data_loader:
            self.optimizer.zero_grad()
            output_pos = self.network(inputs['user_id_pos'], inputs['item_id_pos'], inputs['metadata_pos'])
            output_neg = self.network(inputs['user_id_neg'], inputs['item_id_neg'], inputs['metadata_neg'])
            loss = self.loss(output_pos, output_neg, torch.ones(output_pos.shape))
            for j, data in enumerate(inputs['raw_data']):
                data.weight = loss[j][0].item()
            cumloss += loss.sum().item()
            loss = loss.mean()
            loss.backward()
            self.optimizer.step()
        print(cumloss / len(self.dataset))

    def online(self):
        self.network.eval()
        l = []
        my_state = self.interface.next_state
        for m in self.interface.next_state:
            data = Data(m[0], m[1], m[2:])
            l.append(data)
        input = collate_data(l)
        output = self.network(input['user_id'], input['item_id'], input['metadata']).squeeze()
        recommended_item = output.argmax().item()
        state, reward = self.interface.predict(recommended_item)
        self.dataset.add_data(my_state, recommended_item, reward)
        self.train()

In [13]:
class Argument:
    pass

args = Argument
args.user_id = 'R3EIFXNYY6XMBXBR01BK'
args.ip_address_env_0 = '52.47.62.31'
args.ip_address_env_1 = '35.180.254.42'
args.ip_address_env_2 = '35.180.178.243'

interface = Interface(args)
trainer = Trainer(interface)
interface.reset()
trainer.reset()

0.5735606667524327
0.5701457676252852
0.5648504106681114
0.5628334288041078
0.5674541999089002
0.561578685459456
0.5576747061547752
0.5565233434910158
0.5573690002476748
0.5517563634893421
0.5550728661515634
0.5582864360444865
0.5557974598348382
0.5439000838842147
0.5497409396701389
0.5435412284655449
0.5507863797353544
0.5486463410901887
0.5469788119324253
0.5505961100260417
0.5473346938434829
0.5382688719060923
0.5304984704261403
0.5393281603073325
0.5388333444638103
0.5365355795274401
0.5412022030407003
0.5386790459466088
0.5368348194344696
0.5404612622453493
0.5396469697823973
0.5402041550708994
0.5308701546232936
0.5332542288980626
0.521081151972635
0.5127983930086694
0.5074187460639077
0.500336485612182
0.49078475397425436
0.4908943547387033
0.4900552229092662
0.4900479568192364
0.4910404734653018
0.48598662390364716
0.48213806614497884
0.473140431873834
0.47856830957517105
0.4807998733140936
0.46887802715217325
0.46833323075994665
0.4685540080012031
0.47244644352125736
0.4723587