In [1]:
import os, sys, gc
import time
import glob
import pickle
import copy
import json
import random
from collections import OrderedDict, namedtuple
import multiprocessing
import threading
import traceback

from typing import Tuple, List

import h5py
from tqdm import tqdm, tqdm_notebook

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cv2
from PIL import Image


import torch
import torchvision
import torch.nn.functional as F

from torch import nn, optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchmetrics
import pl_bolts
import pytorch_lightning as pl



from IPython.display import display, clear_output

import faiss

from modules.AugsDS_v7 import *
from modules.eval_functions import *
from modules.eval_metrics import evaluate

sys.path.append('./modules')

 - Found: 861 screenshots.  SCREENSHOT_DIR=./FB_page_qry


In [2]:
from modules.Facebook_model_v20 import ArgsT19_EffNetV2S_ImageNet, FacebookModel

In [3]:
# Start from the default argument set defined for this experiment.
args = ArgsT19_EffNetV2S_ImageNet()

# NOTE(review): presumably skips loading pretrained backbone weights because a
# full checkpoint is restored further down — confirm against FacebookModel.
args.pretrained_bb = False
# Print the resulting configuration so it is recorded in the notebook output.
print(args) 

 = = = = = = = = = = ArgList = = = = = = = = = =
ALL_FOLDERS                   : ['query_images', 'reference_images', 'training_images', 'imagenet_images']
BATCH_SIZE                    : 64
DATASET_WH                    : (384, 384)
DS_DIR                        : ./all_datasets/dataset_jpg_384x384
DS_INPUT_DIR                  : ./all_datasets/dataset
GeM_opt_p                     : True
GeM_p                         : 3.0
ImgNet_SAMPLES                : ./data/ImageNet_samples_v.pickle
N_GPUS                        : 1
N_WORKERS                     : 8
OUTPUT_WH                     : (224, 224)
accelerator                   : ddp
arc_bottleneck                : None
arc_classnum                  : 1200000
arc_gamma                     : 3.0
arc_m                         : 0.4
arc_s                         : 40.0
backbone_name                 : efficientnetv2_rw_s
checkpoint_base_path          : ./TEST19_ArcF_ImgNet
clip_grad_norm                : 1.0
criterion_name                : 

# Building model

In [4]:
# Instantiate the network from the argument set configured in the previous cell.
model = FacebookModel(args)

 - Total weights: 483.64M


# Loading checkpoint

In [5]:
# Checkpoint used for inference; its file name encodes the epoch and the
# train/validation loss and accuracy values. The output file names further
# down are derived from this path.
ckpt_filename = './checkpoints/smp_test19/FacebookModel_Eepoch=67_TLtrn_loss_epoch=0.9913_TAtrn_acc_epoch=0.0000_VLval_loss_epoch=0.4815_VAval_acc_epoch=0.9893.ckpt'
# The return value of restore_checkpoint is not used here, so it is discarded.
_ = model.restore_checkpoint(ckpt_filename)

 - Restored checkpoint: ./checkpoints/smp_test19/FacebookModel_Eepoch=67_TLtrn_loss_epoch=0.9913_TAtrn_acc_epoch=0.0000_VLval_loss_epoch=0.4815_VAval_acc_epoch=0.9893.ckpt.


# Inference configuration

In [6]:
# Retrieval settings.
do_simple_augmentation = False  # forwarded to calc_embed_d; also selects the '_AUG' file-name suffix
K = 500                         # passed as `k` to calc_match_scores

# Data-loading settings.
BATCH_SIZE = 128
N_WORKERS = 7
DS_INPUT_DIR = './all_datasets/dataset'
ALL_FOLDERS = ['query_images', 'reference_images', 'training_images']

# Mirror the notebook-level settings onto the shared argument object so that
# every later cell reads them from `args`.
for _attr, _value in (
    ('ALL_FOLDERS', ALL_FOLDERS),
    ('BATCH_SIZE', BATCH_SIZE),
    ('N_WORKERS', N_WORKERS),
    ('DS_INPUT_DIR', DS_INPUT_DIR),
):
    setattr(args, _attr, _value)

In [7]:
# Drop trailing path separators ('/' or '\') so that the derived directory
# below becomes a sibling of the input directory instead of a child of it
# (e.g. 'dataset/' -> 'dataset_jpg_384x384', not 'dataset/_jpg_384x384').
# If the path consists only of separators it is left untouched.
DS_INPUT_DIR = DS_INPUT_DIR.rstrip('/\\') or DS_INPUT_DIR

# Keep `args` in sync: DS_DIR is built from args.DS_INPUT_DIR, so the cleaned
# value has to be written back for the stripping to have any effect.
args.DS_INPUT_DIR = DS_INPUT_DIR

# Path where the rescaled images will be saved
args.DS_DIR = f'{args.DS_INPUT_DIR}_jpg_{args.DATASET_WH[0]}x{args.DATASET_WH[1]}'

# Data Source

In [8]:
# Build the rescaled copy of the dataset only when at least one of the
# expected sub-folders is missing from DS_DIR.
missing_folders = [
    folder
    for folder in args.ALL_FOLDERS
    if not os.path.exists(os.path.join(args.DS_DIR, folder))
]
if missing_folders:
    assert os.path.exists(args.DS_INPUT_DIR), f'DS_INPUT_DIR not found: {args.DS_INPUT_DIR}'

    resize_dataset(
        ds_input_dir=args.DS_INPUT_DIR,
        ds_output_dir=args.DS_DIR,
        output_wh=args.DATASET_WH,
        output_ext='jpg',
        num_workers=args.N_WORKERS,
        ALL_FOLDERS=args.ALL_FOLDERS,
        verbose=False,
    )

print('Paths:')
print(' - DS_INPUT_DIR:', args.DS_INPUT_DIR)
print(' - DS_DIR:      ', args.DS_DIR)

assert os.path.exists(args.DS_DIR), f'DS_DIR not found: {args.DS_DIR}'

# Load the public ground truth, preferring the copy inside DS_DIR and falling
# back to the one in DS_INPUT_DIR. Only a missing file triggers the fallback;
# any other failure (parse error, interrupt, ...) is allowed to surface
# instead of being silently swallowed.
try:
    public_ground_truth_path = os.path.join(args.DS_DIR, 'public_ground_truth.csv')
    public_gt = pd.read_csv(public_ground_truth_path)
except FileNotFoundError:
    public_ground_truth_path = os.path.join(args.DS_INPUT_DIR, 'public_ground_truth.csv')
    public_gt = pd.read_csv(public_ground_truth_path)

Paths:
 - DS_INPUT_DIR: ./all_datasets/dataset
 - DS_DIR:       ./all_datasets/dataset_jpg_384x384


# Datasets

In [9]:
# Keyword arguments shared by the three inference datasets; augmentation is
# switched off for all of them.
_common_ds_kwargs = dict(
    do_augmentation=False,
    ds_dir=args.DS_DIR,
    output_wh=args.OUTPUT_WH,
    channel_first=True,
    norm_type=args.img_norm_type,
    verbose=True,
)

# Sample ids are a one-letter prefix followed by a zero-padded index:
# 'Q' + 5 digits for queries, 'R' / 'T' + 6 digits for reference / training.
# (Call <dataset>.plot_sample(4) on any of these to eyeball a sample.)
ds_qry_full = FacebookDataset(
    samples_id_v=[f'Q{i:05d}' for i in range(50_000)],
    **_common_ds_kwargs,
)
ds_ref_full = FacebookDataset(
    samples_id_v=[f'R{i:06d}' for i in range(1_000_000)],
    **_common_ds_kwargs,
)
ds_trn_full = FacebookDataset(
    samples_id_v=[f'T{i:06d}' for i in range(1_000_000)],
    **_common_ds_kwargs,
)

# All loaders share the same batch size / worker count and keep the dataset
# order (shuffle=False).
_common_dl_kwargs = dict(
    batch_size=args.BATCH_SIZE,
    num_workers=args.N_WORKERS,
    shuffle=False,
)

dl_qry_full = DataLoader(ds_qry_full, **_common_dl_kwargs)
dl_ref_full = DataLoader(ds_ref_full, **_common_dl_kwargs)
dl_trn_full = DataLoader(ds_trn_full, **_common_dl_kwargs)



### Query embeddings

In [11]:
# Run the restored model over the full query loader and collect its embeddings.
embed_qry_d = calc_embed_d(model, dataloader=dl_qry_full, do_simple_augmentation=do_simple_augmentation)

100%|█████████████████████████████████████████| 391/391 [01:38<00:00,  3.98it/s]


### Reference embeddings

In [20]:
# Output files for the query-vs-reference run are written next to the
# checkpoint and named after it: <ckpt stem>_<W>x<H>[_AUG]_REF.<ext>.
aug = '_AUG' if do_simple_augmentation else ''

# Split off the extension explicitly instead of str.replace('.ckpt', ...):
# replace() is a no-op when the extension is missing (the submission path
# would then equal the checkpoint path and overwrite it) and it would also
# rewrite any other '.ckpt' occurrence inside the path.
output_stem = f'{os.path.splitext(ckpt_filename)[0]}_{args.OUTPUT_WH[0]}x{args.OUTPUT_WH[1]}{aug}_REF'

submission_path = output_stem + '.h5'
scores_path = output_stem + '_match_d.pickle'

In [13]:
# Run the restored model over the full reference loader (the long step of this notebook).
embed_ref_d = calc_embed_d(model, dataloader=dl_ref_full, do_simple_augmentation=do_simple_augmentation)

# Write the query and reference embeddings to the submission file.
save_submission(embed_qry_d, embed_ref_d, save_path=submission_path)

# Query-vs-reference match scores; k=K presumably bounds the number of
# candidates kept per query — TODO confirm against calc_match_scores.
match_d = calc_match_scores(embed_qry_d, embed_ref_d, k=K)
save_obj(match_d, scores_path)

100%|███████████████████████████████████████| 7813/7813 [30:34<00:00,  4.26it/s]


 - Saved: ./checkpoints/smp_test19/FacebookModel_Eepoch=67_TLtrn_loss_epoch=0.9913_TAtrn_acc_epoch=0.0000_VLval_loss_epoch=0.4815_VAval_acc_epoch=0.9893_224x224_REF.h5


100%|█████████████████████████████████████████| 100/100 [06:09<00:00,  3.70s/it]

Saved: ./checkpoints/smp_test19/FacebookModel_Eepoch=67_TLtrn_loss_epoch=0.9913_TAtrn_acc_epoch=0.0000_VLval_loss_epoch=0.4815_VAval_acc_epoch=0.9893_224x224_REF_match_d.pickle





### Public ground-truth validation

In [22]:
# Score the saved submission file against the public ground truth.
# NOTE: `submission_path` must still point at the *_REF.h5 file here; the
# training-embeddings section further down rebinds it to the *_TRN.h5 file.
eval_d = evaluate(submission_path=submission_path, gt_path=public_ground_truth_path, is_matching=False)

./checkpoints/smp_test19/FacebookModel_Eepoch=67_TLtrn_loss_epoch=0.9913_TAtrn_acc_epoch=0.0000_VLval_loss_epoch=0.4815_VAval_acc_epoch=0.9893_224x224_REF.h5
{
  "average_precision": 0.6378110962890179,
  "recall_p90": 0.5223402123822881
}


### Training embeddings

In [15]:
# Output files for the query-vs-training run are written next to the
# checkpoint and named after it: <ckpt stem>_<W>x<H>[_AUG]_TRN.<ext>.
# NOTE: this rebinds `submission_path` / `scores_path`, which previously
# pointed at the *_REF files.
aug = '_AUG' if do_simple_augmentation else ''

# Split off the extension explicitly instead of str.replace('.ckpt', ...):
# replace() is a no-op when the extension is missing (the submission path
# would then equal the checkpoint path and overwrite it) and it would also
# rewrite any other '.ckpt' occurrence inside the path.
output_stem = f'{os.path.splitext(ckpt_filename)[0]}_{args.OUTPUT_WH[0]}x{args.OUTPUT_WH[1]}{aug}_TRN'

submission_path = output_stem + '.h5'
scores_path = output_stem + '_match_d.pickle'

In [16]:
# Run the restored model over the full training-image loader.
embed_trn_d = calc_embed_d(model, dataloader=dl_trn_full, do_simple_augmentation=do_simple_augmentation)

# Write the query and training embeddings to the *_TRN submission file.
save_submission(embed_qry_d, embed_trn_d, save_path=submission_path)

100%|███████████████████████████████████████| 7813/7813 [31:35<00:00,  4.12it/s]


 - Saved: ./checkpoints/smp_test19/FacebookModel_Eepoch=67_TLtrn_loss_epoch=0.9913_TAtrn_acc_epoch=0.0000_VLval_loss_epoch=0.4815_VAval_acc_epoch=0.9893_224x224_TRN.h5


  1%|▍                                       | 1/100 [02:13<3:39:41, 133.15s/it]


KeyboardInterrupt: 

In [17]:
# Match scores between the query and the training embeddings (k=K).
# NOTE: this rebinds `match_d`, replacing the query-vs-reference scores
# computed in the reference-embeddings section.
match_d = calc_match_scores(embed_qry_d, embed_trn_d, k=K)
# `scores_path` is the *_TRN_match_d.pickle path set in the training-embeddings path cell.
save_obj(match_d, scores_path)

100%|█████████████████████████████████████████| 100/100 [06:19<00:00,  3.80s/it]

Saved: ./checkpoints/smp_test19/FacebookModel_Eepoch=67_TLtrn_loss_epoch=0.9913_TAtrn_acc_epoch=0.0000_VLval_loss_epoch=0.4815_VAval_acc_epoch=0.9893_224x224_TRN_match_d.pickle



