# Assignment 3 | Main notebook 
---

In [183]:
import pandas as pd
import numpy as np
from files import p2vmap
import yaml
import torch  # NN -> our own P2V model
import torch.nn
import torch.nn.functional
import os

## Load files

In [2]:
# Raw basket table; the "customer" column is not needed downstream, so drop it on load
baskets = pd.read_parquet('files/market-baskets.parquet').drop(columns="customer")

In [3]:
def _load_yaml(path):
    """Read and parse one YAML config file with `yaml.full_load`."""
    with open(path) as handle:
        return yaml.full_load(handle)


config_w2v = _load_yaml('files/config-w2v.yaml')
config_p2vmap = _load_yaml('files/config-p2vmap.yaml')

In [5]:
# Convert the first 1,000,000 rows of the basket table to list form,
# using the options from the word2vec config's "baskets_df_to_list" section.
w2v_list_kwargs = config_w2v["baskets_df_to_list"]
basket_list = p2vmap.baskets_df_to_list(x=baskets.head(1_000_000), **w2v_list_kwargs)

## Data pipeline (streamer and train/validation data loaders)

In [None]:
# Small slice of the data for fast iteration; keep only baskets that contain
# more than two distinct products.
baskets_2 = baskets.head(1_000)
products_per_basket = baskets_2.groupby("basket")["product"].transform("nunique")
baskets_2 = baskets_2[products_per_basket > 2]

data_config = config_p2vmap["data"]
data_stream_p2v = p2vmap.DataStreamP2V(data=baskets_2, **data_config["data_streamer"])
dl_train, dl_validation = p2vmap.build_data_loader(streamer=data_stream_p2v, config=data_config)

## Helper functions used by the model classes

In [6]:
def _uniform_embedding(n_products, size, init_range=0.025):
    """Create a sparse `torch.nn.Embedding` initialised from U(-init_range, init_range).

    Args:
        n_products: number of rows (one per product index).
        size: embedding dimension.
        init_range: half-width of the uniform initialisation interval.

    Returns:
        A `torch.nn.Embedding(n_products, size, sparse=True)` with uniformly
        initialised weights.
    """
    embedding = torch.nn.Embedding(n_products, size, sparse=True)
    # In-place init must not be tracked by autograd.
    with torch.no_grad():
        embedding.weight.uniform_(-init_range, init_range)
    return embedding


def init_wi(n_products, size, init_range=0.025):
    """Create the input (center-product) embedding table `wi`.

    See `_uniform_embedding` for the parameters; `init_range` defaults to the
    previously hard-coded 0.025.
    """
    return _uniform_embedding(n_products, size, init_range)


def init_wo(n_products, size, init_range=0.025):
    """Create the output (context / negative-sample) embedding table `wo`.

    See `_uniform_embedding` for the parameters; `init_range` defaults to the
    previously hard-coded 0.025.
    """
    return _uniform_embedding(n_products, size, init_range)

In [161]:
def generate_random_batch_of_indexes(dl=None):
    """Return one (center, context, negative_samples) batch from a data loader.

    Args:
        dl: iterable yielding (ce, co, ns) batches. Defaults to the notebook's
            `dl_train`, looked up at call time.

    Returns:
        The first batch yielded by `dl`. It is "random" only insofar as the
        loader itself shuffles.

    Raises:
        ValueError: if `dl` yields no batches.
    """
    if dl is None:
        dl = dl_train
    # Take a single batch instead of exhausting the whole loader.
    for ce, co, ns in dl:
        return ce, co, ns
    raise ValueError("data loader yielded no batches")

In [162]:
def initialize_wi(n_products, size):
    """Create the input embedding table `wi`; alias kept for existing callers.

    Delegates to `init_wi` so the initialisation logic lives in one place
    instead of being copy-pasted.
    """
    return init_wi(n_products, size)


def initialize_wo(n_products, size):
    """Create the output embedding table `wo`; alias kept for existing callers.

    Delegates to `init_wo` so the initialisation logic lives in one place
    instead of being copy-pasted.
    """
    return init_wo(n_products, size)

In [163]:
def get_vectors_from_indexes(wi, wo, ce, co, ns):
    """Look up embedding vectors for a batch of index tensors.

    Args:
        wi: input embedding, applied to the center indexes.
        wo: output embedding, applied to both context and negative-sample indexes.
        ce: center indexes.
        co: context (positive) indexes.
        ns: negative-sample indexes.

    Returns:
        Tuple `(wi(ce), wo(co), wo(ns))`.
    """
    center = wi(ce)
    positives, negatives = wo(co), wo(ns)
    return center, positives, negatives

In [174]:
def get_loss(wi_center, wo_positive_samples, wo_negative_samples):
    """Negative-sampling loss: binary cross-entropy averaged over every scored pair.

    Each (center, context) pair is scored with target 1. Each
    (center, negative sample) pair is scored with target 0. The summed BCE is
    divided by the total number of pairs, batch * (n_negative + 1).

    Args:
        wi_center: (batch, dim) center vectors.
        wo_positive_samples: (batch, dim) context vectors.
        wo_negative_samples: (batch, n_negative, dim) negative-sample vectors.

    Returns:
        Scalar tensor with the mean loss per pair.
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits

    # Dot products: one per positive pair, one per (row, negative) pair.
    pos_logits = torch.einsum("ij,ij->i", wi_center, wo_positive_samples)
    neg_logits = torch.einsum("ik,ijk->ij", wi_center, wo_negative_samples)

    summed = bce(pos_logits, torch.ones_like(pos_logits), reduction="sum") + bce(
        neg_logits, torch.zeros_like(neg_logits), reduction="sum"
    )
    n_pairs = pos_logits.numel() + neg_logits.numel()
    return summed / n_pairs

## Full implementation: one forward pass on a single batch

In [181]:
# Sanity check: embedding lookup + loss for a single batch, using standalone
# embedding tables (not the P2V class).
n_products = 300  # NOTE(review): must exceed the largest product index in the batches — confirm against the data
size = 20  # embedding dimension

ce, co, ns = generate_random_batch_of_indexes()
wi = initialize_wi(n_products, size)
wo = initialize_wo(n_products, size)
vectors = get_vectors_from_indexes(wi, wo, ce, co, ns)
wi_center, wo_positive_samples, wo_negative_samples = vectors
loss = get_loss(*vectors)

## Classes

In [184]:
class P2V(torch.nn.Module):
    """Product2Vec model with negative sampling.

    Holds two trainable sparse embedding tables and returns the
    negative-sampling loss for a batch of index tensors.

    Args:
        n_products: number of rows in each embedding table.
        size: embedding dimension.
        batch_size: stored on the instance; not used by the forward pass.
        n_negative_samples: stored on the instance; not used by the forward pass.
    """

    def __init__(self, n_products, size, batch_size, n_negative_samples):
        super().__init__()
        self.batch_size = batch_size
        self.n_negative_samples = n_negative_samples

        # Trainable variables: assigned as submodules so model.parameters()
        # (and therefore the optimizer) sees them.
        self.wi = init_wi(n_products, size)
        self.wo = init_wo(n_products, size)

    def forward(self, center, context, negative_samples):
        """Return the negative-sampling loss for one batch of index tensors."""
        # Must use this model's own tables (not notebook-level `wi`/`wo`),
        # otherwise gradients never reach the registered parameters.
        wi_center, wo_positive_samples, wo_negative_samples = get_vectors_from_indexes(
            self.wi, self.wo, center, context, negative_samples
        )
        return get_loss(wi_center, wo_positive_samples, wo_negative_samples)

In [None]:
class TrainerP2V:
    """Train a P2V-style model with SparseAdam and log progress to TensorBoard.

    Args:
        model: torch.nn.Module whose forward(ce, co, ns) returns a scalar loss
            and which exposes embeddings `wi` and `wo`.
        train: iterable of (center, context, negative_samples) training batches.
        validation: iterable of batches with the same structure, used for the
            validation loss.
        path: output directory. Weight snapshots go to `{path}/weights`;
            TensorBoard runs go to `{path}/runs/train` and `{path}/runs/val`.
        n_batch_log: validation + weight snapshot on batch 1, 1 + n_batch_log,
            1 + 2 * n_batch_log, ... of each epoch.
    """

    def __init__(self, model, train, validation, path, n_batch_log=500):
        # Local import: `import torch` does not load the tensorboard submodule,
        # and it requires the optional tensorboard package.
        from torch.utils.tensorboard import SummaryWriter

        self.model = model
        self.train = train
        self.validation = validation
        # SparseAdam expects sparse gradients, e.g. from Embedding(sparse=True).
        self.optimizer = torch.optim.SparseAdam(params=list(model.parameters()))
        self.path = path
        # exist_ok so re-running the notebook does not fail on an existing dir.
        os.makedirs(f"{path}/weights", exist_ok=True)
        self.writer_train = SummaryWriter(f"{self.path}/runs/train")
        self.writer_val = SummaryWriter(f"{self.path}/runs/val")
        self.n_batch_log = n_batch_log
        self.global_batch = 0  # batches seen across all epochs (TensorBoard step)
        self.epoch = 0
        self.batch = 0  # batches seen in the current epoch

    def fit(self, n_epochs):
        """Run `n_epochs` passes over the training batches.

        Each batch does: zero gradients, forward pass, backward pass, then
        optimizer step. The training loss is logged every batch. Validation
        and weight snapshots come from the batch/epoch callbacks. Both
        TensorBoard writers are flushed and closed when training finishes.
        """
        self.model.train()
        for _ in range(n_epochs):
            print(f"epoch = {self.epoch}")

            for ce, co, ns in self.train:
                self.batch += 1
                self.global_batch += 1

                self.optimizer.zero_grad()
                loss_train = self.model(ce, co, ns)
                loss_train.backward()
                self.optimizer.step()

                self.writer_train.add_scalar("loss", loss_train.item(), self.global_batch)

                # Fires on batch 1, 1 + n_batch_log, ...; written this way
                # (instead of `batch % n == 1`) so n_batch_log=1 fires every batch.
                if (self.batch - 1) % self.n_batch_log == 0:
                    self._callback_batch()

            self._callback_epoch()
            self.epoch += 1

        self.writer_train.flush()
        self.writer_train.close()
        self.writer_val.flush()
        self.writer_val.close()

    def _callback_batch(self):
        """Log the mean validation loss and save snapshots of both embedding tables."""
        # Validation loss: eval mode and no autograd, then back to training mode.
        self.model.eval()
        with torch.no_grad():
            list_loss_validation = [
                self.model(ce, co, ns).item() for ce, co, ns in self.validation
            ]
        if list_loss_validation:  # avoid np.mean([]) -> nan on an empty loader
            loss_validation = np.mean(list_loss_validation)
            self.writer_val.add_scalar("loss", loss_validation, self.global_batch)
        self.model.train()

        # Save weights, tagged with epoch and per-epoch batch number.
        np.save(
            f"{self.path}/weights/wi_{self.epoch:02d}_{self.batch:06d}.npy",
            self.get_wi(),
        )
        np.save(
            f"{self.path}/weights/wo_{self.epoch:02d}_{self.batch:06d}.npy",
            self.get_wo(),
        )

    def _callback_epoch(self):
        """End-of-epoch hook.

        Validates and snapshots the weights as they stand after the epoch's
        last batch, then restarts the per-epoch batch counter.
        """
        self._callback_batch()
        self.batch = 0

    def get_wi(self):
        """Return a NumPy copy of the input embedding weights."""
        return self.model.wi.weight.detach().cpu().numpy()

    def get_wo(self):
        """Return a NumPy copy of the output embedding weights."""
        return self.model.wo.weight.detach().cpu().numpy()