In [None]:
# default_exp trainers.pl_trainer

In [None]:
#hide
# Environment setup (the /content paths suggest Google Colab — TODO confirm).
# Packages are installed without version pins, so a fresh run picks up whatever
# releases are current; pin them if the recorded outputs must be reproduced.
!pip install pytorch-lightning
%cd /content
# Remove any earlier checkout so the clone below starts clean.
!rm -rf recohut
!git clone --branch US632593 https://github.com/RecoHut-Projects/recohut.git
%cd recohut
# Install the cloned recohut package from its working tree.
!pip install -U .
!apt-get -qq install tree
# watermark is used by the final cell to report versions.
!pip install -q watermark

# PL Trainer
> Implementation of trainer for training PyTorch Lightning models.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
from typing import Any, Iterable, List, Optional, Tuple, Union, Callable
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

In [None]:
#export
def pl_trainer(model, datamodule, max_epochs=10, val_epoch=5, gpus=None, log_dir=None, model_dir=None):
    """Fit ``model`` on ``datamodule`` with a Lightning ``Trainer``, then test it.

    The trainer logs to TensorBoard, keeps a checkpoint named ``recommender``
    for the highest value of the ``"Val Metrics"`` metric, skips the sanity
    validation pass and clips gradients to a norm of 1.

    Args:
        model: Module to train. The checkpoint callback monitors the key
            ``"Val Metrics"``, so the model is expected to log it during
            validation.
        datamodule: Data module passed to both ``Trainer.fit`` and
            ``Trainer.test``.
        max_epochs: Forwarded to ``Trainer(max_epochs=...)``.
        val_epoch: Forwarded to ``Trainer(check_val_every_n_epoch=...)``.
        gpus: GPU selection. Forwarded unchanged as ``Trainer(gpus=...)`` when
            the installed ``Trainer`` accepts that argument. When it does not
            (the argument was removed in newer Lightning releases), a truthy
            value becomes ``accelerator="gpu", devices=gpus`` and a falsy
            value (``None``, ``0``) becomes ``accelerator="cpu"``.
        log_dir: Directory for TensorBoard logs. Defaults to the current
            working directory.
        model_dir: Directory for checkpoints. Defaults to the current working
            directory.

    Returns:
        Whatever ``Trainer.test(model, datamodule=datamodule)`` returns.
    """
    # Function-scope import so the exported block needs no extra module-level import.
    import inspect

    if log_dir is None:
        log_dir = os.getcwd()
    if model_dir is None:
        model_dir = os.getcwd()

    logger = TensorBoardLogger(save_dir=log_dir)

    checkpoint_callback = ModelCheckpoint(
        monitor="Val Metrics",
        mode="max",
        dirpath=model_dir,
        filename="recommender",
    )

    trainer_kwargs = dict(
        max_epochs=max_epochs,
        logger=logger,
        check_val_every_n_epoch=val_epoch,
        callbacks=[checkpoint_callback],
        num_sanity_val_steps=0,
        gradient_clip_val=1,
        gradient_clip_algorithm="norm",
    )

    # Passing ``gpus=`` to a Trainer that no longer declares it raises
    # TypeError, even for the default ``gpus=None``. Only forward it when the
    # installed Trainer accepts it; otherwise express the same device choice
    # through ``accelerator``/``devices``.
    if "gpus" in inspect.signature(Trainer.__init__).parameters:
        trainer_kwargs["gpus"] = gpus
    elif gpus:
        trainer_kwargs.update(accelerator="gpu", devices=gpus)
    else:
        # Keep the legacy meaning of ``gpus=None``/``0``: run on CPU.
        trainer_kwargs["accelerator"] = "cpu"

    trainer = Trainer(**trainer_kwargs)

    trainer.fit(model, datamodule=datamodule)
    test_result = trainer.test(model, datamodule=datamodule)
    return test_result

Example

In [None]:
class Args:
    """Plain settings holder for the example data module.

    The instance ``__dict__`` is unpacked into the data module constructor
    further down, so each attribute name here must match one of that
    constructor's keyword arguments.
    """

    def __init__(self):
        settings = {
            # data location and preprocessing thresholds
            'data_dir': '/content/data',
            'min_rating': 4,
            'num_negative_samples': 99,
            'min_uc': 5,
            'min_sc': 5,
            # split configuration
            'val_p': 0.2,
            'test_p': 0.2,
            # loader options
            'num_workers': 2,
            'normalize': False,
            'batch_size': 32,
            'seed': 42,
            'shuffle': True,
            'pin_memory': True,
            'drop_last': False,
            'split_type': 'stratified',
        }
        # Assign in declaration order so the instance dict matches the table above.
        for option, value in settings.items():
            setattr(self, option, value)

from recohut.datasets.movielens import ML1mDataModule

# Every attribute of Args becomes a keyword argument of the data module.
args = Args()
ds = ML1mDataModule(**vars(args))

# Run the data module's preparation step before the model is built from it.
ds.prepare_data()

Processing...


Turning into implicit ratings
Filtering triplets
Densifying index


Done!


In [None]:
from recohut.models.nmf import NMF

# Size the model from the item and user counts exposed by the prepared data module.
model = NMF(
    n_items=ds.data.num_items,
    n_users=ds.data.num_users,
    embedding_dim=20,
)

In [None]:
# Train for 5 epochs with the remaining pl_trainer defaults (val_epoch defaults
# to 5). The returned test results are the last expression, so they are shown
# as the cell output.
pl_trainer(model, ds, max_epochs=5)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name               | Type      | Params
-------------------------------------------------
0 | user_embedding     | Embedding | 120 K 
1 | item_embedding     | Embedding | 62.5 K
2 | user_embedding_gmf | Embedding | 120 K 
3 | item_embedding_gmf | Embedding | 62.5 K
4 | gmf                | Linear    | 210   
5 | fc1                | Linear    | 820   
6 | fc2                | Linear    | 420   
7 | fc3                | Linear    | 210   
8 | fc_final           | Linear    | 21    
9 | dropout            | Dropout   | 0     
-------------------------------------------------
368 K     Trainable params
0         Non-trainable params
368 K     Total params
1.472     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Test Metrics': {'apak': tensor(0.0752),
                  'hr': tensor(0.2275),
                  'loss': tensor(0.1823),
                  'ncdg': tensor(0.1102)}}
--------------------------------------------------------------------------------


[{'Test Metrics': {'apak': tensor(0.0752),
   'hr': tensor(0.2275),
   'loss': tensor(0.1823),
   'ncdg': tensor(0.1102)}}]

In [None]:
#hide
# Record author, timestamp, machine details and package versions (including
# recohut) for this run.
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d -p recohut

Author: Sparsh A.

Last updated: 2022-01-10 10:43:14

recohut: 0.0.10

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.144+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

IPython: 5.5.0

