In [1]:
import logging
import sys
import json
import torch
from argparse import ArgumentParser
import pytorch_lightning as pl
import numpy as np
from data import MovieLensDataLoader
from model import GRU4Rec, MLMRecSys
from transformers import AutoTokenizer
from tqdm import tqdm
from utils import compute_metrics

In [2]:
# Tokenizer for the 'bert-base-cased' vocabulary.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
# Restore the recommender from a Lightning checkpoint, move it to the first GPU
# and put it in eval mode.
model = MLMRecSys.load_from_checkpoint('wandb/run-20210722_070252-21se3gnd/files/LMRecSys/21se3gnd/checkpoints/epoch=7-step=5342.ckpt').to('cuda:0').eval()
# Item-id -> title lookup. Read inside a `with` block so the file handle is
# closed as soon as the JSON has been parsed.
with open('datasets/MovieLens-1M-5Star/id2name.json') as id2name_file:
    id2name = json.load(id2name_file)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
def convert_to_features(ids, id2name, tokenizer, n_mask=10, max_seq_length=128, input_str=None, to_tensor=True, device='cuda:0'):
    """Turn a watch history (or a ready-made prompt) into masked-LM inputs.

    Exactly one of ``ids`` and ``input_str`` must be supplied. When ``ids`` is
    given, a prompt of the form
    ``'A user watched <t1>, <t2>. Now the user may want to watch<mask>...'``
    is generated; when ``input_str`` is given it is tokenized verbatim.

    Args:
        ids: Iterable of item ids to look up in ``id2name``, or None when
            ``input_str`` is supplied.
        id2name: Lookup indexed by item id that returns the item's title
            (a string, since it is concatenated into the prompt).
        tokenizer: Callable tokenizer exposing ``mask_token`` and
            ``mask_token_id``.
        n_mask: Number of mask tokens appended to a generated prompt. Not used
            when ``input_str`` is supplied.
        max_seq_length: Length the encoding is padded / truncated to.
        input_str: Prompt to tokenize as-is, or None to build one from ``ids``.
        to_tensor: When true, ``input_ids``, ``token_type_ids``,
            ``attention_mask`` and ``mask_idxs`` are each wrapped in a leading
            dimension of size 1 and converted to tensors on ``device``.
        device: Target device for the tensors (only used if ``to_tensor``).

    Returns:
        The tokenizer output with two extra entries: ``'input_str'`` (the
        prompt that was tokenized) and ``'mask_idxs'`` (positions in
        ``input_ids`` that hold the mask token id).

    Raises:
        AssertionError: If both ``ids`` and ``input_str`` are supplied.
        ValueError: If neither ``ids`` nor ``input_str`` is supplied.
    """

    def pattern(ids, id2name, mask_token, n_mask):
        # Joining produces the same text as appending '<title>, ' per item and
        # trimming the final comma, but an empty history now yields
        # 'A user watched . ' instead of clipping the 'd' off 'watched'.
        watched = ', '.join(id2name[item_id] for item_id in ids)
        return 'A user watched ' + watched + '. ' + 'Now the user may want to watch' + mask_token * n_mask

    assert ids is None or input_str is None, 'pass either ids or input_str, not both'
    if ids is None and input_str is None:
        # Fail with a clear message instead of a TypeError from iterating None.
        raise ValueError('either ids or input_str must be provided')

    if input_str is None:
        input_str = pattern(ids, id2name, tokenizer.mask_token, n_mask)

    # NOTE(review): with truncation enabled, a prompt longer than
    # max_seq_length may lose its trailing mask tokens, leaving fewer than
    # n_mask entries in mask_idxs. Assumes the tokenizer truncates on the
    # right; confirm, and raise max_seq_length for long item titles.
    result = tokenizer(
        input_str,
        add_special_tokens=True,
        max_length=max_seq_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_token_type_ids=True
    )
    result['input_str'] = input_str
    result['mask_idxs'] = [
        position
        for position, token_id in enumerate(result['input_ids'])
        if token_id == tokenizer.mask_token_id
    ]

    if to_tensor:
        # Wrap each list in an outer list so the tensors carry a batch axis of 1.
        for key in ['input_ids', 'token_type_ids', 'attention_mask', 'mask_idxs']:
            result[key] = torch.tensor([result[key]], device=device)

    return result

In [4]:
# Reproduce the numbers
# One JSON record per line; the `with` block closes the file once it is read.
with open('datasets/MovieLens-1M-5Star/data.jsonl') as data_file:
    data = [json.loads(line) for line in data_file]

all_labels, all_top_preds = [], []
# Inference only: no_grad() keeps autograd from recording each forward pass.
with torch.no_grad():
    for item in tqdm(data):
        # The first five ids form the prompt; the sixth is the item to predict.
        ids = item['ids'][:5]
        inputs = convert_to_features(ids, id2name, tokenizer, input_str=None)
        outputs = model(inputs)
        logits = outputs['label_logits_aggregated']
        # argsort(-1) orders item indices by ascending score along the last axis.
        # NOTE(review): position 0 therefore holds the lowest score; confirm
        # that lower is better for 'label_logits_aggregated'.
        top_preds = logits.argsort(-1)
        all_top_preds.append(top_preds[0].tolist())
        all_labels.append(item['ids'][5])
compute_metrics(np.array(all_top_preds), np.array(all_labels), prefix='val')

100%|██████████| 5337/5337 [01:37<00:00, 54.82it/s]


{'val/r@20': 0.12235338204984074,
 'val/mrr@20': 0.02702487683190047,
 'val/ndcg@20': 0.05358626384134088}

In [4]:
# Test input ids
# Reload the interaction log so this cell can be run on its own; the `with`
# block closes the file once every line has been parsed.
with open('datasets/MovieLens-1M-5Star/data.jsonl') as data_file:
    data = [json.loads(line) for line in data_file]

item = data[0]
# Prompt the model with the first five items of this record.
ids = item['ids'][:5]
inputs = convert_to_features(ids, id2name, tokenizer, input_str=None)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    outputs = model(inputs)
logits = outputs['label_logits_aggregated']
# Item indices ordered by ascending score along the last axis.
top_preds = logits.argsort(-1)

print(inputs['input_str'])
print([id2name[item_id] for item_id in top_preds[0][:5]])

A user watched Ben-Hur, Dumbo, Schindler's List, Beauty and the Beast, Toy Story. Now the user may want to watch[MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK]
['Good Will Hunting', 'Beauty and the Beast', 'Toy Story', 'Braveheart', 'The Lion King']


In [5]:
# Test input strings
# The prompt is typed interactively and tokenized verbatim, so it presumably
# needs to contain the tokenizer's mask token itself (convert_to_features only
# appends masks when it builds the prompt from ids).
input_str = input()
inputs = convert_to_features(None, id2name, tokenizer, input_str=input_str)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    outputs = model(inputs)
logits = outputs['label_logits_aggregated']
# Item indices ordered by ascending score along the last axis.
top_preds = logits.argsort(-1)

print(inputs['input_str'])
# Show each of the first five ranked items together with its aggregated score.
display([(id2name[item_id], logits[0][item_id].item()) for item_id in top_preds[0][:5]])

Now the user may want to watch[MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK]


[('The Sixth Sense', 44.42577362060547),
 ('Pulp Fiction', 44.58979034423828),
 ('The Shawshank Redemption', 44.60279083251953),
 ('Fight Club', 44.667724609375),
 ('The Princess Bride', 44.74372100830078)]

In [56]:
# MLM
# Sanity check of the wrapped masked-LM (`model.model`) on a plain sentence.
input_str = 'It is a [MASK] day.'

# Alternative kept for reference (presumably a model without the checkpoint
# weights; confirm before using):
# model = MLMRecSys().to('cuda:0').eval()
inputs = tokenizer(input_str, return_tensors='pt')
mask_idx = [
    position
    for position, token_id in enumerate(inputs['input_ids'][0])
    if token_id == tokenizer.mask_token_id
]
assert len(mask_idx) == 1, 'expected exactly one mask token in input_str'
mask_idx = mask_idx[0]
# Move every encoded tensor to the GPU the model lives on.
for key in inputs: inputs[key] = inputs[key].to('cuda:0')
print(inputs, mask_idx)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    mask_logits = model.model(**inputs).logits[0][mask_idx]
# argsort() is ascending, so reverse it and keep the ten highest-scoring tokens.
top_preds = mask_logits.argsort().tolist()[::-1][:10]
tokenizer.decode(top_preds)

{'input_ids': tensor([[ 101, 1135, 1110,  170,  103, 1285,  119,  102]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')} 4


'. " by long the good ; - October new'

# AmazonPantry

In [6]:
# Switch the notebook over to the AmazonPantry catalogue.
# Read the id -> title lookup inside a `with` block so the file is closed.
with open('datasets/AmazonPantry/id2name.json') as id2name_file:
    id2name = json.load(id2name_file)
# NOTE(review): this rebinds `id2name` and changes the shared `model` in place,
# so cells that expect the MovieLens setup presumably must not be re-run after
# this point without reloading the model and id2name first.
model.hparams.data_dir = 'datasets/AmazonPantry/'
model.create_verbalizers()

In [52]:
# Reproduce the numbers
# One JSON record per line; the `with` block closes the file once it is read.
with open('datasets/AmazonPantry/data.jsonl') as data_file:
    data = [json.loads(line) for line in data_file]

all_labels, all_top_preds = [], []
# Inference only: no_grad() keeps autograd from recording each forward pass.
with torch.no_grad():
    # Only the first 2000 records are scanned.
    for item in tqdm(data[:2000]):
        # Skip histories with fewer than 7 interactions.
        # NOTE(review): only 6 ids are read below; confirm why 7 is the cutoff.
        if len(item['ids']) < 7: continue
        # The first five ids form the prompt; the sixth is the item to predict.
        ids = item['ids'][:5]
        inputs = convert_to_features(ids, id2name, tokenizer, input_str=None, max_seq_length=512)
        outputs = model(inputs)
        logits = outputs['label_logits_aggregated']
        # argsort(-1) orders item indices by ascending score along the last axis.
        top_preds = logits.argsort(-1)
        all_top_preds.append(top_preds[0].tolist())
        all_labels.append(item['ids'][5])
compute_metrics(np.array(all_top_preds), np.array(all_labels), prefix='val')

100%|██████████| 2000/2000 [00:17<00:00, 112.15it/s]


{'val/r@20': 0.0047562425683709865,
 'val/mrr@20': 0.0019157088122605365,
 'val/ndcg@20': 0.0027743726963040762}

In [15]:
# Test input ids
# Reload the interaction log so this cell can be run on its own; the `with`
# block closes the file once every line has been parsed.
with open('datasets/AmazonPantry/data.jsonl') as data_file:
    data = [json.loads(line) for line in data_file]

item = data[9]
print(item['ids'])
# Prompt the model with the first five items of this record.
ids = item['ids'][:5]
inputs = convert_to_features(ids, id2name, tokenizer, input_str=None, max_seq_length=256)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    outputs = model(inputs)
logits = outputs['label_logits_aggregated']
# Item indices ordered by ascending score along the last axis.
top_preds = logits.argsort(-1)

# Prompt, the ids it was built from, and the held-out sixth item (id and title).
print(inputs['input_str'], ids, item['ids'][5], id2name[item['ids'][5]])
print([id2name[item_id] for item_id in top_preds[0][:5]])

[2321, 2365, 2371, 2357, 85, 426, 426, 1482, 64, 451, 451, 3359, 2055, 754, 445, 445, 3018, 397, 397, 3047, 529, 1026, 3029, 639, 1, 1094, 1490, 2040, 2710, 151, 2285, 1410, 891]
A user watched King Arthur Flour 100% Organic Unbleached All-Purpose Flour, 80 Ounce., Nestle Toll House DelightFulls Milk Chocolate Morsels with Caramel Filling, 9 Ounce., Nestle Toll House DelightFulls Dark Chocolate Morsels with Mint Filling, 9 Ounce., Nestle Toll House Semi-Sweet Chocolate Morsels, 24 Ounce., Lundberg Family Farms Organic Basmati Rice, California Brown, 32 Ounce.. Now the user may want to watch[MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK] [2321, 2365, 2371, 2357, 85] 426 Huy Fong Hot Chili Sauce, Sriracha, 28 oz.
['Crazy Cups Coffee, Hot Chocolate and Irish Creme Cheesecake, 22 Count.', 'Best Foods Creamy Real Mayonnaise, Gluten Free, Kosher, 64 oz.', 'Crazy Cups Coffee, Hot Chocolate and Salted Caramel, 22 Count.', 'Crazy Cups Coffee, Tea and Hot Chocolate Variety Sampler P