In [None]:
# default_exp trainers.trainer

# Trainer
> Implementation of a PyTorch model trainer.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
# dataset
from recohut.datasets.movielens import ML1mRatingDataset

# models
from recohut.models.afm import AFM
from recohut.models.afn import AFN
from recohut.models.autoint import AutoInt
from recohut.models.dcn import DCN
from recohut.models.deepfm import DeepFM
from recohut.models.ffm import FFM
from recohut.models.fm import FM
from recohut.models.fnfm import FNFM
from recohut.models.fnn import FNN
from recohut.models.hofm import HOFM
from recohut.models.lr import LR
from recohut.models.ncf import NCF
from recohut.models.nfm import NFM
from recohut.models.ncf import NCF
from recohut.models.pnn import PNN
from recohut.models.wide_and_deep import WideAndDeep
from recohut.models.xdeepfm import xDeepFM

In [None]:
# Build the MovieLens-1M rating dataset under /content/ML1m.
# min_uc / min_sc are forwarded unchanged to the dataset class.
ds = ML1mRatingDataset(
    root='/content/ML1m',
    min_uc=10,
    min_sc=5,
)

Downloading http://files.grouplens.org/datasets/movielens/ml-1m.zip
Extracting /content/ML1m/raw/ml-1m.zip
Processing...
Done!


In [None]:
import torch
import os
import tqdm
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader

In [None]:
class Args:
    """Bundle of dataset, training and architecture settings for one run.

    Every setting is stored as a plain instance attribute, so it can be
    overridden after construction; ``get_dataset`` and ``get_model`` read
    the attributes at call time.

    Parameters
    ----------
    dataset : str
        Dataset key. Only ``'ml_1m'`` is wired up in this class.
    model : str
        Model key. One of the keys of ``_ARCH_PARAMS``
        (e.g. ``'wide_and_deep'``, ``'fm'``, ``'lr'``, ``'afn'``).

    Notes
    -----
    The constructor accepts unknown keys without complaint (it simply sets
    no dataset/architecture attributes for them); the error is raised by
    ``get_dataset`` / ``get_model`` when the key is actually needed.
    """

    # Architecture hyperparameters set as attributes for each model key.
    # 'lr' needs none: the model is built from ``field_dims`` alone.
    _ARCH_PARAMS = {
        'wide_and_deep': dict(embed_dim=16, mlp_dims=(16, 16)),
        'fm': dict(embed_dim=16),
        'lr': dict(),
        'ffm': dict(embed_dim=4),
        'hofm': dict(embed_dim=16, order=3),
        'fnn': dict(embed_dim=16, mlp_dims=(16, 16)),
        'ipnn': dict(embed_dim=16, mlp_dims=(16,), method='inner'),
        'opnn': dict(embed_dim=16, mlp_dims=(16,), method='outer'),
        'dcn': dict(embed_dim=16, num_layers=3, mlp_dims=(16, 16)),
        'nfm': dict(embed_dim=64, mlp_dims=(64,), dropouts=(0.2, 0.2)),
        'ncf': dict(embed_dim=16, mlp_dims=(16, 16)),
        'fnfm': dict(embed_dim=4, mlp_dims=(64,), dropouts=(0.2, 0.2)),
        'deep_fm': dict(embed_dim=16, mlp_dims=(16, 16)),
        'xdeep_fm': dict(embed_dim=16, cross_layer_sizes=(16, 16),
                         split_half=False, mlp_dims=(16, 16)),
        'afm': dict(embed_dim=16, attn_size=16, dropouts=(0.2, 0.2)),
        'autoint': dict(embed_dim=16, atten_embed_dim=64, num_heads=2,
                        num_layers=3, mlp_dims=(400, 400),
                        dropouts=(0, 0, 0)),
        'afn': dict(embed_dim=16, LNN_dim=1500, mlp_dims=(400, 400, 400),
                    dropouts=(0, 0, 0)),
    }

    def __init__(self,
                 dataset='ml_1m',
                 model='wide_and_deep'
                 ):
        self.dataset = dataset
        self.model = model
        # dataset
        if dataset == 'ml_1m':
            self.dataset_root = '/content/ML1m'
            self.min_uc = 20
            self.min_sc = 20

        # model training
        self.device = 'cpu' # 'cuda:0'
        self.num_workers = 2
        self.batch_size = 256
        self.lr = 0.001
        self.weight_decay = 1e-6
        self.save_dir = '/content/chkpt'
        self.n_epochs = 2
        self.dropout = 0.2
        self.log_interval = 100

        # model architecture (nothing is set for keys missing from the table)
        for name, value in self._ARCH_PARAMS.get(model, {}).items():
            setattr(self, name, value)

    def get_dataset(self):
        """Instantiate the dataset selected by ``self.dataset``.

        Raises
        ------
        ValueError
            If ``self.dataset`` is not a supported key.
        """
        if self.dataset == 'ml_1m':
            return ML1mRatingDataset(root = self.dataset_root,
                                     min_uc = self.min_uc,
                                     min_sc = self.min_sc
                                     )
        raise ValueError(
            f"Unknown dataset {self.dataset!r}; supported datasets: ['ml_1m']")

    def get_model(self, field_dims, user_field_idx=None, item_field_idx=None):
        """Instantiate the model selected by ``self.model``.

        Parameters
        ----------
        field_dims
            Passed as the first positional argument to every model class.
        user_field_idx, item_field_idx
            Only forwarded to ``NCF``; ignored by every other model.

        Raises
        ------
        ValueError
            If ``self.model`` is not a supported key.
        """
        if self.model == 'wide_and_deep':
            return WideAndDeep(field_dims,
                               embed_dim=self.embed_dim,
                               mlp_dims = self.mlp_dims,
                               dropout = self.dropout
                               )
        elif self.model == 'fm':
            return FM(field_dims,
                      embed_dim = self.embed_dim
                      )
        elif self.model == 'lr':
            return LR(field_dims
                      )
        elif self.model == 'ffm':
            return FFM(field_dims,
                       embed_dim = self.embed_dim
                      )
        elif self.model == 'hofm':
            return HOFM(field_dims,
                        embed_dim = self.embed_dim,
                        order = self.order
                      )
        elif self.model == 'fnn':
            return FNN(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       dropout = self.dropout
                      )
        elif self.model in ('ipnn', 'opnn'):
            # Both keys build the same class; they differ only in self.method.
            return PNN(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       method = self.method,
                       dropout = self.dropout
                      )
        elif self.model == 'dcn':
            return DCN(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       num_layers = self.num_layers,
                       dropout = self.dropout,
                      )
        elif self.model == 'nfm':
            return NFM(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       dropouts = self.dropouts,
                      )
        elif self.model == 'ncf':
            return NCF(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       dropout = self.dropout,
                       user_field_idx=user_field_idx,
                       item_field_idx=item_field_idx
                      )
        elif self.model == 'fnfm':
            return FNFM(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       dropouts = self.dropouts,
                      )
        elif self.model == 'deep_fm':
            return DeepFM(field_dims,
                          embed_dim = self.embed_dim,
                          mlp_dims = self.mlp_dims,
                          dropout = self.dropout,
                      )
        elif self.model == 'xdeep_fm':
            return xDeepFM(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       dropout = self.dropout,
                       cross_layer_sizes = self.cross_layer_sizes,
                       split_half = self.split_half,
                      )
        elif self.model == 'afm':
            return AFM(field_dims,
                       embed_dim = self.embed_dim,
                       dropouts = self.dropouts,
                       attn_size = self.attn_size,
                      )
        elif self.model == 'autoint':
            return AutoInt(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       dropouts = self.dropouts,
                       atten_embed_dim = self.atten_embed_dim,
                       num_heads = self.num_heads,
                       num_layers = self.num_layers,
                      )
        elif self.model == 'afn':
            return AFN(field_dims,
                       embed_dim = self.embed_dim,
                       mlp_dims = self.mlp_dims,
                       dropouts = self.dropouts,
                       LNN_dim = self.LNN_dim,
                      )
        # Fail loudly instead of returning None, which would only surface
        # later as an opaque AttributeError in the caller.
        raise ValueError(
            f"Unknown model {self.model!r}; "
            f"supported models: {sorted(self._ARCH_PARAMS)}")

In [None]:
class EarlyStopper:
    """Track the best score seen so far and decide whether to keep training.

    Parameters
    ----------
    num_trials : int
        Number of consecutive non-improving checks after which
        ``is_continuable`` returns ``False``.
    save_path : str
        File the model is written to (via ``torch.save``) each time the
        score improves.
    """

    def __init__(self, num_trials, save_path):
        self.save_path = save_path
        self.num_trials = num_trials
        self.best_accuracy = 0
        self.trial_counter = 0

    def is_continuable(self, model, accuracy):
        """Record ``accuracy`` and report whether training should go on.

        A strictly better score resets the patience counter and saves
        ``model`` to ``save_path``. Otherwise the counter is advanced until
        the patience budget is used up, at which point ``False`` is returned.
        """
        improved = accuracy > self.best_accuracy
        if improved:
            self.best_accuracy = accuracy
            self.trial_counter = 0
            torch.save(model, self.save_path)
            return True

        # No improvement: stop once one more miss would exhaust the budget.
        if self.trial_counter + 1 >= self.num_trials:
            return False

        self.trial_counter += 1
        return True

In [None]:
class Trainer:
    """Run one full train / validate / test cycle for the model named in ``args``.

    The whole experiment executes inside ``__init__``: the dataset is split
    80/10/10 into train/validation/test, the model is trained for up to
    ``args.n_epochs`` epochs with early stopping on validation AUC, and the
    test AUC is printed at the end.

    Parameters
    ----------
    args
        Configuration object providing ``device``, ``batch_size``,
        ``num_workers``, ``lr``, ``weight_decay``, ``save_dir``,
        ``n_epochs``, ``get_dataset()`` and ``get_model(...)``.
        ``log_interval`` is optional and defaults to 100.

    Attributes
    ----------
    model
        The model as it stands after the last executed epoch.
    best_valid_auc
        Best validation AUC recorded by the early stopper.
    test_auc
        AUC of ``model`` on the held-out test split.
    """

    def __init__(self, args):
        device = torch.device(args.device)
        # dataset
        dataset = args.get_dataset()
        # model
        model = args.get_model(dataset.field_dims,
                               user_field_idx = dataset.user_field_idx,
                               item_field_idx = dataset.item_field_idx)
        model = model.to(device)
        model_name = type(model).__name__
        # data split: test takes whatever is left so the three sizes sum exactly
        train_length = int(len(dataset) * 0.8)
        valid_length = int(len(dataset) * 0.1)
        test_length = len(dataset) - train_length - valid_length
        # data loader
        train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
            dataset, (train_length, valid_length, test_length))
        train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        valid_data_loader = DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        # handlers
        criterion = torch.nn.BCELoss()
        optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        os.makedirs(args.save_dir, exist_ok=True)
        # NOTE: the checkpoint is named after the model *class*, so configs
        # that build the same class write to the same file.
        early_stopper = EarlyStopper(num_trials=2, save_path=f'{args.save_dir}/{model_name}.pt')
        # Honour the configured logging cadence; fall back to the _train default.
        log_interval = getattr(args, 'log_interval', 100)
        # training
        for epoch_i in range(args.n_epochs):
            self._train(model, optimizer, train_data_loader, criterion, device,
                        log_interval=log_interval)
            auc = self._test(model, valid_data_loader, device)
            print('epoch:', epoch_i, 'validation: auc:', auc)
            if not early_stopper.is_continuable(model, auc):
                print(f'validation: best auc: {early_stopper.best_accuracy}')
                break
        # NOTE: scored with the in-memory model from the last executed epoch,
        # not with a reloaded best checkpoint.
        auc = self._test(model, test_data_loader, device)
        print(f'test auc: {auc}')
        # Keep the results reachable after construction.
        self.model = model
        self.best_valid_auc = early_stopper.best_accuracy
        self.test_auc = auc

    @staticmethod
    def _train(model, optimizer, data_loader, criterion, device, log_interval=100):
        """Run one training epoch over ``data_loader``.

        The running loss shown in the progress bar is the mean over the last
        ``log_interval`` batches and is reset after each report.
        """
        model.train()
        total_loss = 0
        tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
        for i, (fields, target) in enumerate(tk0):
            fields, target = fields.to(device), target.to(device)
            y = model(fields)
            loss = criterion(y, target.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if (i + 1) % log_interval == 0:
                tk0.set_postfix(loss=total_loss / log_interval)
                total_loss = 0

    @staticmethod
    def _test(model, data_loader, device):
        """Return the ROC AUC of ``model`` over every batch in ``data_loader``."""
        model.eval()
        targets, predicts = list(), list()
        # Evaluation only: no gradients are needed.
        with torch.no_grad():
            for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
                fields, target = fields.to(device), target.to(device)
                y = model(fields)
                targets.extend(target.tolist())
                predicts.extend(y.tolist())
        return roc_auc_score(targets, predicts)

In [None]:
# First batch of models; 'autoint' and 'afn' are trained in a separate cell.
models = [
    'wide_and_deep', 'fm', 'lr', 'ffm', 'hofm',
    'fnn', 'ipnn', 'opnn', 'dcn', 'nfm',
    'ncf', 'fnfm', 'deep_fm', 'xdeep_fm', 'afm',
]

# Constructing a Trainer runs the whole train/validate/test cycle.
for model in models:
    args = Args(model=model)
    trainer = Trainer(args)

Processing...
Done!
100%|██████████| 3126/3126 [00:23<00:00, 135.91it/s, loss=0.57]
100%|██████████| 391/391 [00:01<00:00, 252.62it/s]


epoch: 0 validation: auc: 0.7781601005135064


100%|██████████| 3126/3126 [00:22<00:00, 137.03it/s, loss=0.557]
100%|██████████| 391/391 [00:01<00:00, 259.00it/s]


epoch: 1 validation: auc: 0.7842454181872189


100%|██████████| 391/391 [00:01<00:00, 261.87it/s]
Processing...


test auc: 0.783773499847308


Done!
100%|██████████| 3126/3126 [00:15<00:00, 200.85it/s, loss=0.587]
100%|██████████| 391/391 [00:01<00:00, 277.38it/s]


epoch: 0 validation: auc: 0.7511323978391329


100%|██████████| 3126/3126 [00:15<00:00, 203.24it/s, loss=0.542]
100%|██████████| 391/391 [00:01<00:00, 286.19it/s]


epoch: 1 validation: auc: 0.7852232398637453


100%|██████████| 391/391 [00:01<00:00, 286.62it/s]
Processing...


test auc: 0.7851983970544512


Done!
100%|██████████| 3126/3126 [00:12<00:00, 243.01it/s, loss=0.713]
100%|██████████| 391/391 [00:01<00:00, 290.07it/s]


epoch: 0 validation: auc: 0.606845663039941


100%|██████████| 3126/3126 [00:12<00:00, 243.68it/s, loss=0.625]
100%|██████████| 391/391 [00:01<00:00, 290.60it/s]


epoch: 1 validation: auc: 0.6962495583229628


100%|██████████| 391/391 [00:01<00:00, 280.21it/s]
Processing...


test auc: 0.6917994954031111


Done!
100%|██████████| 3126/3126 [00:16<00:00, 189.89it/s, loss=0.639]
100%|██████████| 391/391 [00:01<00:00, 275.44it/s]


epoch: 0 validation: auc: 0.6956660360854087


100%|██████████| 3126/3126 [00:16<00:00, 190.43it/s, loss=0.559]
100%|██████████| 391/391 [00:01<00:00, 279.32it/s]


epoch: 1 validation: auc: 0.769259926201433


100%|██████████| 391/391 [00:01<00:00, 275.84it/s]
Processing...


test auc: 0.7694256825177728


Done!
100%|██████████| 3126/3126 [00:23<00:00, 135.66it/s, loss=0.585]
100%|██████████| 391/391 [00:01<00:00, 229.47it/s]


epoch: 0 validation: auc: 0.7508361070441243


100%|██████████| 3126/3126 [00:23<00:00, 135.52it/s, loss=0.538]
100%|██████████| 391/391 [00:01<00:00, 229.00it/s]


epoch: 1 validation: auc: 0.7867336507526798


100%|██████████| 391/391 [00:01<00:00, 226.59it/s]
Processing...


test auc: 0.7849653473859624


Done!
100%|██████████| 3126/3126 [00:21<00:00, 146.97it/s, loss=0.554]
100%|██████████| 391/391 [00:01<00:00, 272.83it/s]


epoch: 0 validation: auc: 0.7899586086314532


100%|██████████| 3126/3126 [00:21<00:00, 148.00it/s, loss=0.544]
100%|██████████| 391/391 [00:01<00:00, 268.07it/s]


epoch: 1 validation: auc: 0.7938707366151592


100%|██████████| 391/391 [00:01<00:00, 267.69it/s]
Processing...


test auc: 0.7935777015287597


Done!
100%|██████████| 3126/3126 [00:20<00:00, 151.01it/s, loss=0.55]
100%|██████████| 391/391 [00:01<00:00, 253.01it/s]


epoch: 0 validation: auc: 0.7901787607777198


100%|██████████| 3126/3126 [00:20<00:00, 151.50it/s, loss=0.536]
100%|██████████| 391/391 [00:01<00:00, 258.53it/s]


epoch: 1 validation: auc: 0.7958062417181883


100%|██████████| 391/391 [00:01<00:00, 265.78it/s]
Processing...


test auc: 0.7959379435427811


Done!
100%|██████████| 3126/3126 [00:21<00:00, 144.98it/s, loss=0.548]
100%|██████████| 391/391 [00:01<00:00, 256.43it/s]


epoch: 0 validation: auc: 0.7943316704845618


100%|██████████| 3126/3126 [00:21<00:00, 145.26it/s, loss=0.53]
100%|██████████| 391/391 [00:01<00:00, 252.76it/s]


epoch: 1 validation: auc: 0.8027591784990165


100%|██████████| 391/391 [00:01<00:00, 259.79it/s]
Processing...


test auc: 0.8016146552653354


Done!
100%|██████████| 3126/3126 [00:26<00:00, 116.37it/s, loss=0.537]
100%|██████████| 391/391 [00:01<00:00, 240.12it/s]


epoch: 0 validation: auc: 0.7898151214837668


100%|██████████| 3126/3126 [00:26<00:00, 116.92it/s, loss=0.527]
100%|██████████| 391/391 [00:01<00:00, 239.57it/s]


epoch: 1 validation: auc: 0.7955138244674892


100%|██████████| 391/391 [00:01<00:00, 240.84it/s]
Processing...


test auc: 0.7964998271099959


Done!
100%|██████████| 3126/3126 [00:22<00:00, 138.66it/s, loss=0.586]
100%|██████████| 391/391 [00:01<00:00, 252.66it/s]


epoch: 0 validation: auc: 0.7631548463451637


100%|██████████| 3126/3126 [00:22<00:00, 137.08it/s, loss=0.551]
100%|██████████| 391/391 [00:01<00:00, 251.84it/s]


epoch: 1 validation: auc: 0.7752154803420491


100%|██████████| 391/391 [00:01<00:00, 252.42it/s]
Processing...


test auc: 0.7727792981788815


Done!
100%|██████████| 3126/3126 [00:23<00:00, 132.24it/s, loss=0.554]
100%|██████████| 391/391 [00:01<00:00, 248.61it/s]


epoch: 0 validation: auc: 0.7876433331502086


100%|██████████| 3126/3126 [00:23<00:00, 132.00it/s, loss=0.543]
100%|██████████| 391/391 [00:01<00:00, 249.84it/s]


epoch: 1 validation: auc: 0.7923030405914255


100%|██████████| 391/391 [00:01<00:00, 257.83it/s]
Processing...


test auc: 0.7930787548185895


Done!
100%|██████████| 3126/3126 [00:23<00:00, 133.99it/s, loss=0.61]
100%|██████████| 391/391 [00:01<00:00, 250.20it/s]


epoch: 0 validation: auc: 0.7376150945998978


100%|██████████| 3126/3126 [00:23<00:00, 135.18it/s, loss=0.583]
100%|██████████| 391/391 [00:01<00:00, 246.25it/s]


epoch: 1 validation: auc: 0.7583206065924306


100%|██████████| 391/391 [00:01<00:00, 245.77it/s]
Processing...


test auc: 0.7594084947700983


Done!
100%|██████████| 3126/3126 [00:24<00:00, 127.45it/s, loss=0.569]
100%|██████████| 391/391 [00:01<00:00, 244.49it/s]


epoch: 0 validation: auc: 0.7806048647711028


100%|██████████| 3126/3126 [00:24<00:00, 128.70it/s, loss=0.554]
100%|██████████| 391/391 [00:01<00:00, 246.00it/s]


epoch: 1 validation: auc: 0.7857091265544482


100%|██████████| 391/391 [00:01<00:00, 245.62it/s]
Processing...


test auc: 0.7857843263334994


Done!
100%|██████████| 3126/3126 [00:30<00:00, 103.68it/s, loss=0.558]
100%|██████████| 391/391 [00:01<00:00, 219.06it/s]


epoch: 0 validation: auc: 0.7814674364890849


100%|██████████| 3126/3126 [00:29<00:00, 104.41it/s, loss=0.539]
100%|██████████| 391/391 [00:01<00:00, 214.94it/s]


epoch: 1 validation: auc: 0.7899837530572655


100%|██████████| 391/391 [00:01<00:00, 216.12it/s]
Processing...


test auc: 0.7863345464272122


Done!
100%|██████████| 3126/3126 [00:23<00:00, 133.32it/s, loss=0.606]
100%|██████████| 391/391 [00:01<00:00, 244.59it/s]


epoch: 0 validation: auc: 0.7590887701790624


100%|██████████| 3126/3126 [00:23<00:00, 134.06it/s, loss=0.576]
100%|██████████| 391/391 [00:01<00:00, 247.06it/s]


epoch: 1 validation: auc: 0.7820711568875622


100%|██████████| 391/391 [00:01<00:00, 247.66it/s]

test auc: 0.7835448236219698





In [None]:
# Second batch: the two models kept out of the first training cell.
models = ['autoint', 'afn']

# Constructing a Trainer runs the whole train/validate/test cycle.
for model in models:
    args = Args(model=model)
    trainer = Trainer(args)

Processing...
Done!
100%|██████████| 3126/3126 [00:43<00:00, 72.44it/s, loss=0.551]
100%|██████████| 391/391 [00:02<00:00, 171.82it/s]


epoch: 0 validation: auc: 0.7838440329134869


100%|██████████| 3126/3126 [00:42<00:00, 73.24it/s, loss=0.532]
100%|██████████| 391/391 [00:02<00:00, 172.49it/s]


epoch: 1 validation: auc: 0.7924653055551055


100%|██████████| 391/391 [00:02<00:00, 169.07it/s]
Processing...


test auc: 0.7935854845577758


Done!
100%|██████████| 3126/3126 [01:24<00:00, 37.12it/s, loss=0.564]
100%|██████████| 391/391 [00:03<00:00, 107.76it/s]


epoch: 0 validation: auc: 0.7796980126749351


100%|██████████| 3126/3126 [01:23<00:00, 37.50it/s, loss=0.547]
100%|██████████| 391/391 [00:03<00:00, 108.16it/s]


epoch: 1 validation: auc: 0.7879478169612124


100%|██████████| 391/391 [00:03<00:00, 108.09it/s]

test auc: 0.7893059350190452





In [None]:
# List the contents of the checkpoint directory with human-readable sizes.
!tree --du -h -C /content/chkpt

[01;34m/content/chkpt[00m
├── [669K]  AFM.pt
├── [ 39M]  AFN.pt
├── [1.5M]  AutoInt.pt
├── [640K]  DCN.pt
├── [676K]  DeepFM.pt
├── [355K]  FFM.pt
├── [666K]  FM.pt
├── [363K]  FNFM.pt
├── [636K]  FNN.pt
├── [1.3M]  HOFM.pt
├── [ 41K]  LR.pt
├── [636K]  NCF.pt
├── [2.5M]  NFM.pt
├── [1.2M]  PNN.pt
├── [676K]  WideAndDeep.pt
└── [682K]  xDeepFM.pt

  51M used in 0 directories, 16 files
