# Summary
This notebook demonstrates the use of **PyTorch Lightning**, with most utility functions moved into utility scripts to keep the notebook clean.

Most of the code is based on and inspired by the following notebooks:
* https://www.kaggle.com/tanulsingh077/metric-learning-pipeline-only-text-sbert
* https://www.kaggle.com/underwearfitting/pytorch-densenet-arcface-validation-training/notebook

**Inference notebook**:<br>
https://www.kaggle.com/kcostya/ride-the-lightning-inference/

In [None]:
import math
import shutil
import warnings

import numpy as np
import pytorch_lightning as pl
import torch
import transformers
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch import nn as nn
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import AdamW, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

from shopee_datasets import ShopeeDataModule
from shopee_models import ArcMarginProduct, find_threshold, ShopeeNet
from shopee_utils import KerasProgressBar, seed_everything

In [None]:
warnings.filterwarnings("ignore")

### Config

In [None]:
# Input data locations (Kaggle dataset layout) and fold count.
CSV_TRAIN = "../input/shopee-product-matching/train.csv"
IMAGES_TRAIN = "../input/shopee-product-matching/train_images"
N_SPLITS = 5  # forwarded to ShopeeDataModule(n_splits=...)

# Training-loop settings.
NUM_WORKERS = 4
TRAIN_BATCH_SIZE = 256
EPOCHS = 30  # used for Trainer(max_epochs=...) and to size num_training_steps
SEED = 23  # used for seed_everything(...) and ShopeeDataModule(random_state=...)
# NOTE(review): LR is not passed to ShopeeNet or the Trainer in any cell of
# this notebook - confirm where (or whether) it is consumed.
LR = 5e-5
TRANSFORMER_MODEL = "sentence-transformers/paraphrase-xlm-r-multilingual-v1"
TOKENIZER_MAX_LEN = 50

# FC layer parameters
USE_FC = True
FC_DIM = 512
DROPOUT = 0.2

# Metric loss and its parameters
LOSS_MODULE = "arcface"
S = 30.0  # forwarded as ShopeeNet(s=...)
M = 0.3  # forwarded as ShopeeNet(margin=...)
LS_EPS = 0.01
EASY_MARGIN = False
THETA_ZERO = math.pi / 4

# Candidate values 55, 58, ..., 82 forwarded as ShopeeNet(search_space=...).
# Presumably the grid scanned by find_threshold - verify in shopee_models.
SEARCH_SPACE = np.arange(55, 85, 3)

# LR-schedule factory. A `global` statement is a no-op at module/cell scope,
# so the name is simply assigned here.
# NOTE(review): `schedule` is not referenced by any other cell in this
# notebook, and imported modules cannot see notebook globals - confirm
# whether it is still needed.
schedule = get_cosine_schedule_with_warmup

In [None]:
# Call the project helper from shopee_utils with SEED before the data module
# and model are built (presumably seeds the Python/NumPy/torch RNGs - see
# shopee_utils for the exact behaviour).
seed_everything(SEED)

### Model

In [None]:
# Keep only the single best checkpoint, ranked by the highest "val_score",
# and write it under ./checkpoints.
# NOTE(review): depending on the installed Lightning version, "{epoch:02d}"
# may be rendered with the metric name included (e.g. "epoch=03"), which would
# produce names like "ckpt-epochepoch=03-val_scoreval_score=0.75". Check the
# files actually written before simplifying the template - other notebooks may
# rely on the current naming.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="ckpt-epoch{epoch:02d}-val_score{val_score:.2f}",
    monitor="val_score",
    mode="max",
    save_top_k=1,
)

# Abort training once "val_score" has failed to improve for 8 consecutive
# checks; any increase counts as an improvement (min_delta is zero).
early_stop_callback = EarlyStopping(
    monitor="val_score",
    mode="max",
    patience=8,
    min_delta=0.0,
    verbose=False,
)

### Run!

In [None]:
# Build the project data module from the config constants.
dm = ShopeeDataModule(
    # where the data lives
    path_to_csv=CSV_TRAIN,
    path_to_images=IMAGES_TRAIN,
    # fold assignment
    n_splits=N_SPLITS,
    random_state=SEED,
    # text tokenisation
    tokenizer_path="tokenizer",
    tokenizer_max_len=TOKENIZER_MAX_LEN,
    # batching / loading
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# setup() is called by hand here; the model-parameter cell reads dm.data and
# dm.num_batches, which are presumably populated by this call.
# NOTE(review): "train" is not one of Lightning's standard stage names
# ("fit", "validate", "test", "predict") - confirm ShopeeDataModule.setup
# accepts it.
dm.setup("train")

In [None]:
# Keyword arguments for ShopeeNet, assembled from the config constants and the
# prepared data module.
model_params = dict(
    # NOTE(review): hard-coded class count - presumably the number of distinct
    # label groups in train.csv; confirm against the data.
    n_classes=11014,
    # Rows assigned to fold 0 are handed to the model as the validation frame.
    valid_df=dm.data.query("fold==0"),
    # Backbone and projection head.
    model_name=TRANSFORMER_MODEL,
    use_fc=USE_FC,
    fc_dim=FC_DIM,
    dropout=DROPOUT,
    # Metric-loss settings.
    loss_module=LOSS_MODULE,
    s=S,
    margin=M,
    ls_eps=LS_EPS,
    easy_margin=EASY_MARGIN,
    theta_zero=THETA_ZERO,
    # LR-schedule sizing: warm up for 2 x dm.num_batches steps out of
    # EPOCHS x dm.num_batches total (assumes num_batches is the number of
    # batches per epoch - TODO confirm in shopee_datasets).
    num_warmup_steps=dm.num_batches * 2,
    num_training_steps=dm.num_batches * EPOCHS,
    # Candidate values forwarded to the model for its validation search.
    search_space=SEARCH_SPACE,
)

model = ShopeeNet(**model_params)

# Project-specific progress bar, registered as a Trainer callback below.
bar = KerasProgressBar()

In [None]:
# Configure the Lightning trainer:
#   - run for at most EPOCHS epochs (early stopping may end it sooner),
#   - request one GPU when CUDA is available, otherwise pass None,
#   - clip gradients at 0.5,
#   - attach the progress bar, checkpointing and early-stopping callbacks.
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    gpus=1 if torch.cuda.is_available() else None,
    gradient_clip_val=0.5,
    callbacks=[bar, checkpoint_callback, early_stop_callback],
)

# Train the model on the batches served by the data module.
trainer.fit(model, dm)

In [None]:
# Delete the "lightning_logs" directory from the working directory after
# training. Pure-Python replacement for `%rm -rf`: it does not depend on
# IPython's automagic alias or a POSIX `rm` binary, and ignore_errors=True
# keeps the "-f" behaviour of not failing when the directory is absent.
shutil.rmtree("lightning_logs", ignore_errors=True)