# Summary
This notebook runs inference with **PyTorch Lightning**. Most utility functions have been moved into separate utility scripts to keep the notebook clean.

Most of the code is based on and inspired by the following notebooks:
* https://www.kaggle.com/tanulsingh077/metric-learning-pipeline-only-text-sbert
* https://www.kaggle.com/underwearfitting/pytorch-densenet-arcface-validation-training/notebook

**Training notebook**:<br>
https://www.kaggle.com/kcostya/ride-the-lightning-training/

In [None]:
import gc
import os

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from tqdm.auto import tqdm

from shopee_datasets import ShopeeDataModule
from shopee_models import ShopeeNet
from shopee_utils import seed_everything

In [None]:
# Locate the trained checkpoint shipped with the training notebook's output.
# os.walk also yields directories that contain no files (e.g. a root folder that
# only holds sub-folders); skip those instead of crashing on an empty list.
CKPT_DIR = "../input/train-notebook/checkpoints/"

ckpt_path = None
for ckpt_dir, _subdirs, ckpt_files in os.walk(CKPT_DIR):
    if ckpt_files:
        # As before, the last non-empty directory visited wins; sorting makes the
        # pick deterministic when a directory holds several files.
        ckpt_path = os.path.join(ckpt_dir, sorted(ckpt_files)[0])

if ckpt_path is None:
    # Fail here with a clear message rather than with a NameError further down.
    raise FileNotFoundError(f"No checkpoint file found under {CKPT_DIR}")

In [None]:
# When True, run on train.csv and the datamodule's validation loader instead of
# the test set.
debug: bool = False  # debug mode requires Internet connection

In [None]:
# Paths and hyper-parameters for inference.
PATH_TO_CKPT = ckpt_path
TOKENIZER_PATH = "../input/train-notebook/tokenizer/"
TRANSFORMER_PATH = "../input/train-notebook/transformer/"
CSV_TEST = "../input/shopee-product-matching/test.csv"
CSV_SUBMISSION = "../input/shopee-product-matching/sample_submission.csv"
IMAGES_TEST = "../input/shopee-product-matching/test_images"
N_SPLITS = 5  # forwarded to ShopeeDataModule (presumably the number of folds)
NUM_WORKERS = 4
TEST_BATCH_SIZE = 32
SEED = 23
TOKENIZER_MAX_LEN = 10
N_BATCH = 10  # number of query chunks used by the similarity search
SIM_THRESH = 0.65  # two postings match when their embedding dot product exceeds this

if debug:
    # Debug mode evaluates on the training CSV. The training images are shipped in
    # train_images, so the image folder must be switched together with the CSV.
    CSV_TEST = "../input/shopee-product-matching/train.csv"
    IMAGES_TEST = "../input/shopee-product-matching/train_images"
    TOKENIZER_PATH = "./"

In [None]:
# Seed the random number generators via the project helper from shopee_utils
# (presumably Python/NumPy/torch RNGs — confirm in shopee_utils).
seed_everything(SEED)

In [None]:
# Data module over CSV_TEST (train.csv in debug mode) and its image folder.
dm = ShopeeDataModule(
    path_to_csv=CSV_TEST,
    path_to_images=IMAGES_TEST,
    n_splits=N_SPLITS,
    random_state=SEED,
    batch_size=TEST_BATCH_SIZE,
    tokenizer_max_len=TOKENIZER_MAX_LEN,
    num_workers=NUM_WORKERS,
    tokenizer_path=TOKENIZER_PATH,
)

# Debug runs prepare the "train" stage (its val_dataloader is embedded later);
# normal runs prepare the "test" stage.
dm.setup("train" if debug else "test")

In [None]:
# Restore the trained network; test_mode is turned off in debug mode, where the
# model is called directly as model(a, b) instead of through model.predict.
model = ShopeeNet.load_from_checkpoint(
    PATH_TO_CKPT, transformer_path=TRANSFORMER_PATH, test_mode=not debug
)

In [None]:
# Embed every posting: the validation loader in debug mode, the test loader otherwise.
dataloader = dm.val_dataloader() if debug else dm.test_dataloader()

# torch modules start in training mode and nothing in this notebook switched the
# model out of it; without eval(), dropout / batch-norm would act as in training.
model.eval()
# NOTE(review): model and batches stay on the CPU here. The previous
# pl.Trainer(gpus=...) object was never used, so it moved nothing to the GPU.

prediction_list = []
with torch.no_grad():
    for i, batch in enumerate(tqdm(dataloader)):
        if debug:
            # Debug batches carry a third element that is not needed for inference.
            a, b, _ = batch
            output = model(a, b)
        else:
            output = model.predict(batch, i)
        prediction_list.append(output.cpu())

# One row per posting, in loader order.
feats = torch.cat(prediction_list).numpy()

In [None]:
# Free the model and any cached GPU memory before the similarity search.
del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

if debug:
    # Keep only as many rows as there are embeddings. .copy() makes `test` an
    # independent frame, so the columns added later do not hit pandas'
    # chained-assignment (SettingWithCopyWarning) behaviour on a slice.
    # NOTE(review): this assumes the validation loader yields the first
    # len(feats) rows of dm.data, in order — confirm against ShopeeDataModule.
    test = dm.data.iloc[: len(feats)].copy()
else:
    test = dm.data
del dm
gc.collect()

In [None]:
def combine_for_sub(row):
    """Merge phash-based and embedding-based matches into one submission string.

    Args:
        row: record exposing ``"preds_phash"`` and ``"preds_bert"``, each an
            iterable of posting ids.

    Returns:
        The distinct posting ids from both sources, sorted, joined by spaces.
    """
    merged = set(row["preds_phash"])
    merged.update(row["preds_bert"])
    return " ".join(sorted(merged))

In [None]:
# Every posting matches all postings sharing its exact image_phash value
# (itself included).
phash_groups = test.groupby("image_phash")["posting_id"].unique()
test["preds_phash"] = test["image_phash"].map(phash_groups)

In [None]:
def _match_by_embedding(embeddings, posting_ids, threshold, n_chunks):
    """Return, for every row, the posting ids whose similarity exceeds ``threshold``.

    Query rows are processed in at most ``n_chunks`` slices, each compared
    against all embeddings, so the full n x n matrix is never held at once.

    Args:
        embeddings: tensor with one embedding per row (assumed (n, d)).
        posting_ids: numpy array of posting ids aligned with the embedding rows.
        threshold: a pair matches when its dot product is strictly above this.
        n_chunks: desired number of query slices (values < 1 are treated as 1).

    Returns:
        A list with one list of matching posting ids per row.
    """
    n_rows = embeddings.shape[0]
    chunk_size = max(1, -(-n_rows // max(1, n_chunks)))  # ceiling division
    matches = []
    for start in tqdm(range(0, n_rows, chunk_size)):
        chunk = embeddings[start:start + chunk_size]
        selection = (chunk @ embeddings.T > threshold).cpu().numpy()
        # Boolean-mask a plain array instead of building a DataFrame per row.
        matches.extend(posting_ids[mask].tolist() for mask in selection)
    return matches


n, _ = feats.shape

# Per the original "don't do anything during commit" note, a 3-row set is treated
# as the commit-time public test.csv. NOTE(review): confirm that row count.
if n != 3:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): thresholding a raw dot product is cosine similarity only if
    # the model emits L2-normalised embeddings — confirm in ShopeeNet.
    feats = torch.from_numpy(feats).to(device)

    test["preds_bert"] = _match_by_embedding(
        feats, test["posting_id"].to_numpy(), SIM_THRESH, N_BATCH
    )
    test["matches"] = test.apply(combine_for_sub, axis=1)

    # Take posting_id from `test` itself so each row's matches stay paired with
    # its own id, instead of relying on sample_submission's row order.
    submission = test[["posting_id", "matches"]]
    submission.to_csv("submission.csv", index=False)

# don't do anything during commit
else:
    submission = pd.read_csv(CSV_SUBMISSION)
    submission.to_csv("submission.csv", index=False)