In [1]:
import numpy as np
import torch
from scipy import sparse as sp
from torch import nn
from torch.nn import functional as F
from torch.optim import Optimizer, Adam
from tqdm import trange

from src.recommender_model import RecommenderModel
from src.utils import train_model, evaluate_model, plot_losses

# MultVAE
This notebook provides an implementation of the MultVAE model. See [Variational Autoencoders for Collaborative Filtering, Dawen Liang](https://arxiv.org/pdf/1802.05814).

In [2]:
class VAE(nn.Module):
	def __init__(self, input_dim: int, hidden_dims: list[int] = None, latent_dim: int = 256, dropout: float = .5):
		super(VAE, self).__init__()

		activation = nn.Tanh()
		self.latent_dim = latent_dim

		if hidden_dims is None:
			hidden_dims = [1024]

		encoder_layers = []
		prev_dim = input_dim
		for hidden_dim in hidden_dims:
			encoder_layers.extend([
				nn.Linear(prev_dim, hidden_dim),
				activation,
			])
			prev_dim = hidden_dim
		self.encoder = nn.Sequential(*encoder_layers)

		self.distribution_parameters = nn.Linear(hidden_dims[-1], 2 * latent_dim)

		decoder_layers = []
		prev_dim = latent_dim
		for hidden_dim in reversed(hidden_dims):
			decoder_layers.extend([
				nn.Linear(prev_dim, hidden_dim),
				activation,
			])
			prev_dim = hidden_dim
		decoder_layers.append(nn.Linear(prev_dim, input_dim))
		self.decoder = nn.Sequential(*decoder_layers)

		self.dropout = nn.Dropout(dropout)

		self.init_weights()

	def init_weights(self):
		def init_layer(layer):
			if isinstance(layer, nn.Linear):
				nn.init.xavier_uniform_(layer.weight)
				if layer.bias is not None:
					nn.init.normal_(layer.bias, std=0.01)

		for layer in self.encoder:
			init_layer(layer)
		init_layer(self.distribution_parameters)
		for layer in self.decoder:
			init_layer(layer)

	def encode(self, x: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
		x = F.normalize(x)
		x = self.dropout(x)
		x = self.encoder(x)
		distribution_parameters = self.distribution_parameters(x)
		return distribution_parameters[:, self.latent_dim:], distribution_parameters[:, :self.latent_dim]  # [mu, log_var]

	def reparameterize(self, mu, log_var):
		if self.training:
			std = torch.exp(0.5 * log_var)
			eps = torch.randn_like(std)
			return mu + eps * std
		else:
			return mu

	def decode(self, z: torch.tensor) -> torch.tensor:
		return self.decoder(z)

	def forward(self, input: torch.tensor) -> tuple[torch.tensor, torch.tensor, torch.tensor]:
		mu, log_var = self.encode(input)
		z = self.reparameterize(mu, log_var)
		return self.decode(z), mu, log_var

In [3]:
def multinomial_loss(recon_x, x, mu, log_var, beta_anneal) -> torch.tensor:
	neg_ll = -torch.mean(torch.sum(x * F.log_softmax(recon_x, dim=1), dim=-1))
	kld = -0.5 * torch.mean(torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
	return neg_ll + beta_anneal * kld

In [4]:
def get_anneal_rate(training_iteration, total_training_iteration, beta_cap) -> float:
	return min(beta_cap, training_iteration / total_training_iteration)

In [5]:
class UsersDataloader:
	def __init__(self, urm: sp.csr_matrix, batch_size: int = 512, shuffle: bool = True):
		self.urm = urm
		self.num_users = self.urm.shape[0]
		self.users_idx = torch.arange(self.urm.shape[0]).long()

		self.batch_size = batch_size
		self.curr_batch_idx = 0

		self.length = self.num_users // self.batch_size + 1

		if shuffle:
			self.users_idx = self.users_idx[torch.randperm(self.num_users)]

	def __iter__(self):
		self.curr_batch_idx = 0
		return self

	def __next__(self):
		if self.curr_batch_idx >= self.length:
			raise StopIteration

		self.curr_batch_idx += 1
		return torch.from_numpy(self.urm[
			self.users_idx[(self.curr_batch_idx - 1) * self.batch_size:self.curr_batch_idx * self.batch_size]
		].toarray())

	def __len__(self):
		return self.length

In [6]:
class MultVAEPR(RecommenderModel):
	def __init__(self):
		super(MultVAEPR, self).__init__()
		self.vae: VAE | None = None
		self.optimizer: Optimizer | None = None
		self.beta_cap: float = 0
		self.loss_fn = None
		self.best_map: float = 0

	def fit(self, urm: sp.csr_matrix, urm_val: sp.csr_matrix, progress_bar: bool = True, hidden_dims: list[int] = None, latent_dim: int = 64, lr: float = 1e-3, beta_cap: float = .4, dropout: float = .5, weight_decay: float = 1e-8, batch_size: int = 512, epochs: int = 50, plot_loss: bool = True, **kwargs) -> None:
		self.urm = urm
		self.beta_cap = beta_cap

		self.vae = VAE(
			input_dim=urm.shape[1],
			hidden_dims=hidden_dims,
			latent_dim=latent_dim,
			dropout=dropout,
		)
		self.optimizer = Adam(self.vae.parameters(), lr=lr, weight_decay=weight_decay)
		self.loss_fn = multinomial_loss

		dataloader = UsersDataloader(self.urm, batch_size=batch_size)
		dataloader_val = UsersDataloader(urm_val, batch_size=batch_size)
		dl_len = len(dataloader)
		total_training_iterations = epochs * dl_len

		loss_history = np.zeros((epochs * dl_len))
		loss_history_val = np.zeros((epochs + 1))
		map_history = np.zeros((epochs + 1))

		validation_enabled = urm_val.nnz > 0
		if validation_enabled:
			self._compute_full_urm_pred()
			map_history[0], loss_history_val[0] = self._validate(dataloader_val, urm_val)

		iterator = (t := trange(epochs, desc="Training...")) if progress_bar else range(epochs)
		for epoch in iterator:
			self.vae.train()
			for batch_idx, users_batch in enumerate(dataloader):
				recon_x, mu, log_var = self.vae(users_batch)

				anneal = get_anneal_rate(epoch * batch_size + batch_idx, total_training_iterations, self.beta_cap)
				loss = self.loss_fn(recon_x, users_batch, mu, log_var, anneal)

				neg_ll = -torch.mean(torch.sum(users_batch * F.log_softmax(recon_x, dim=1), dim=-1))
				kld = -0.5 * torch.mean(torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))

				self.optimizer.zero_grad()
				loss.backward()
				self.optimizer.step()

				if progress_bar:
					with torch.no_grad():
						t.set_postfix({
							"neg_ll": f"{neg_ll.item():.2f}",
							"kld": f"{kld.item():.2f}",
							"Batch": f"{(batch_idx + 1) / dl_len * 100:.2f}%",
							"Train loss": f"{loss.item():.5f}",
							"Val loss": f"{loss_history_val[epoch]:.5f}",
							"MAP@10": f"{map_history[epoch]:.5f}",
							"Best MAP@10": f"{self.best_map:.5f}",
						})
			if validation_enabled:
				self._compute_full_urm_pred()
				map_history[epoch + 1], loss_history_val[epoch + 1] = self._validate(dataloader_val, urm_val)
				self.best_map = max(self.best_map, map_history[epoch + 1])
		if not validation_enabled:
			self._compute_full_urm_pred()
		if plot_loss:
			plot_losses(epochs, loss_history, loss_history_val, len(dataloader), ('MAP@10', [x * len(dataloader) for x in range(epochs + 1)], map_history))


	@torch.no_grad()
	def _compute_full_urm_pred(self, batch_size: int = 4096) -> None:
		del self.urm_pred
		self.urm_pred = torch.zeros(self.urm.shape, dtype=torch.float32)
		dataloader = UsersDataloader(self.urm, batch_size=batch_size, shuffle=False)
		for batch_idx, users_batch in enumerate(dataloader):
			recon_users_batch, _, _ = self.vae(users_batch)
			self.urm_pred[batch_idx * batch_size:(batch_idx + 1) * batch_size] = F.softmax(recon_users_batch, dim=-1)

		self.urm_pred = self.urm_pred.cpu().numpy()

	@torch.no_grad()
	def _validate(self, dataloader_val, urm_val):
		self.vae.eval()
		loss = 0
		for users_batch in dataloader_val:
			recon_users_batch, mu, log_var = self.vae(users_batch)
			loss += self.loss_fn(recon_users_batch, users_batch, mu, log_var, self.beta_cap)
		return evaluate_model(self, urm_val, users_to_test=.2), (loss / len(dataloader_val)).item()

In [7]:
_, _ = train_model(MultVAEPR(), test_size=.2, epochs=20, hidden_dims=[1024], lr=1e-3, latent_dim=256, weight_decay=0, dropout=.5, batch_size=512, beta_cap=.2,)

Training...:  10%|█         | 2/20 [09:51<1:28:45, 295.85s/it, neg_ll=323.81, kld=31.11, Batch=100.00%, Train loss=330.03296, Val loss=88.08788, MAP@10=0.01251, Best MAP@10=0.01251] 


KeyboardInterrupt: 