[Dataset Ensemble]  
1) Deit Cropped Dataset  
2) YOLO V5 Cropped Dataset

In [1]:
# Install FAISS into the running kernel environment via a shell escape.
# NOTE(review): no version is pinned; the recorded output of this cell shows
# faiss-1.5.3 being installed, so a fresh run may resolve a different release.
!pip install faiss
# !pip install torchmetrics

Collecting faiss
  Downloading faiss-1.5.3-cp37-cp37m-manylinux1_x86_64.whl (4.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: faiss
Successfully installed faiss-1.5.3
[0m

In [2]:
import os, sys, gc, time, random, warnings, math, cv2
import wandb, optuna, faiss, timm, torch
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import albumentations as albu
from pytorch_lightning.callbacks import EarlyStopping
from torchmetrics.retrieval import RetrievalMAP
from albumentations.pytorch import ToTensorV2
from PIL import Image
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from timm.data.transforms_factory import create_transform
from timm.optim import create_optimizer_v2
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import normalize
from sklearn.preprocessing import LabelEncoder
from kaggle_secrets import UserSecretsClient
from glob import glob
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple
from torch.autograd import Variable
warnings.filterwarnings("ignore")

In [3]:
# WandB Login => Copy API Key
# Fetch the value stored under the "wandb" label in the Kaggle user secrets,
# so the API key never has to be written into the notebook itself.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")

# `$secret_value_0` is IPython variable interpolation into the shell command.
!wandb login $secret_value_0

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [4]:
class CFG:
    """
    Static configuration namespace: every option is a plain class attribute.

    NOTE(review): in the cells visible here only `seed` (via `all_type_seed` and
    `stratifiedkfold`) and `n_folds` (via `stratifiedkfold`) are read; `train()`
    takes its own keyword defaults instead of reading this class. Confirm whether
    the remaining options are consumed elsewhere before relying on them.
    """
    checkpoint_dir = './saved/model'
    name = 'HappyWhale'
    model = 'convnext_base_384_in22ft1k'

    """ Common Options """
    wandb = True
    optuna = True  # if you want to tune hyperparameter, set True
    competition = 'HappyWhale'
    seed = 42
    cfg_name = 'CFG'
    n_gpu = 1
    
    # Evaluated once, when the class body runs: first CUDA device if present, else CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    gpu_id = 0
    num_workers = 0

    """ Data Options """
    n_folds = 5
    epochs = 180
    img_size = 384
    batch_size = 64

    """ Gradient Options """
    amp_scaler = True
    gradient_checkpoint = True  # save parameter
    clipping_grad = True  # clip_grad_norm
    n_gradient_accumulation_steps = 1
    max_grad_norm = 1000

    """ Loss & Metrics Options """
    loss_fn = ''
    reduction = 'mean'
    metrics = ['MCRMSE', 'f_beta', 'recall']

    """ Optimizer with LLRD Options """
    optimizer = 'AdamW'  # options: SWA, AdamW
    llrd = True
    layerwise_lr = 5e-5
    layerwise_lr_decay = 0.9
    layerwise_weight_decay = 1e-2
    layerwise_adam_epsilon = 1e-6
    layerwise_use_bertadam = False
    betas = (0.9, 0.999)

    """ Scheduler Options """
    scheduler = 'cosine_annealing'  # options: cosine, linear, cosine_annealing, linear_annealing
    batch_scheduler = True
    num_cycles = 0.5  # num_warmup_steps = 0
    warmup_ratio = 0.1  # options: 0.05, 0.1

    """ SWA Options """
    swa = True
    # Derived from `epochs` above at class-creation time (int(180 * 0.75) == 135);
    # changing `CFG.epochs` afterwards does not update this value.
    swa_start = int(epochs*0.75)
    swa_lr = 1e-4
    anneal_epochs = 4
    anneal_strategy = 'cos'  # default = cos, available option: linear

    """ Model_Utils Options """
    freeze = False
    reinit = True
    num_reinit = 5
    awp = False
    nth_awp_start_epoch = 10
    awp_eps = 1e-2
    awp_lr = 1e-4

In [5]:
""" Helper Function """

def check_device() -> bool:
    """
    Report whether the Apple-Silicon MPS backend is usable.

    Returns:
        True when `torch.backends.mps` exists and reports availability,
        False otherwise (including torch builds that predate the MPS backend).
    """
    # `torch.backends.mps.is_available()` is the documented query; the
    # `torch.mps` module is missing from many torch builds, so go through
    # `torch.backends` and fall back to False instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    return bool(mps_backend is not None and mps_backend.is_available())

def check_library(checker: bool) -> Optional[tuple]:
    """
    Inspect the cuDNN backend when the current device is not MPS.

    Args:
        checker: True means the current device is mps, False means it is
            cuda with cudnn.

    Returns:
        None when `checker` is True (nothing is inspected on the MPS path);
        otherwise the tuple `(is_available, is_enabled, version)` read from
        `torch.backends.cudnn`.
    """
    if checker:
        # MPS path: there is no cuDNN state to report.
        return None

    is_built = torch.backends.cudnn.is_available()
    # The original read `cudnn.enabledtorch.backends.cudnn.enabled`, a paste
    # typo that raised AttributeError; the intended flag is `cudnn.enabled`.
    is_enabled = torch.backends.cudnn.enabled
    version = torch.backends.cudnn.version()
    return (is_built, is_enabled, version)

def class2dict(cfg) -> dict:
    """Map every attribute name of `cfg` that does not start with '__' to its value."""
    return {
        attr: getattr(cfg, attr)
        for attr in dir(cfg)
        if not attr.startswith('__')
    }


def all_type_seed(cfg, checker: bool) -> None:
    """
    Seed Python, NumPy and PyTorch from `cfg.seed`.

    Args:
        cfg: object exposing a `seed` attribute.
        checker: when False, the CUDA generators are seeded as well and the
            cuDNN flags are set (deterministic on, benchmark off, enabled off).
    """
    seed = cfg.seed

    # Interpreter-level and CPU-side generators.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if checker:
        # Nothing CUDA/cuDNN related is touched on this path.
        return

    # CUDA generators: current device, then all devices.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False

def seed_worker(worker_id) -> None:
    """
    Re-seed `random` and NumPy from torch's initial seed, reduced modulo 2**32.

    `worker_id` is accepted but not read.
    """
    seed_32bit = torch.initial_seed() % (2 ** 32)
    random.seed(seed_32bit)
    np.random.seed(seed_32bit)
    

# NOTE(review): both helpers receive a hard-coded True, i.e. the "mps" branch:
# `check_library` then inspects nothing, and `all_type_seed` skips its CUDA
# seeding and cuDNN flag section. Confirm this is intended on a CUDA machine
# (`check_device()` is defined above but never called in the cells shown here).
check_library(True)
all_type_seed(CFG, True)

# Explicitly seeded torch generator.
# NOTE(review): neither `g` nor `seed_worker` is passed to any DataLoader in
# the cells visible here - confirm whether they are used further down.
g = torch.Generator()
g.manual_seed(CFG.seed)

<torch._C.Generator at 0x728e32898590>

In [6]:
""" Path & Settings """
""" For Dataset 1 """

# Kaggle layout: read-only inputs live one level above the working directory,
# everything this notebook writes goes to /kaggle/working.
INPUT_DIR = Path("..") / "input"
OUTPUT_DIR = Path("/") / "kaggle" / "working"

# Dataset 1: back-fin crops plus the competition CSV files.
DATA_ROOT_DIR = INPUT_DIR / "convert-backfintfrecords" / "happy-whale-and-dolphin-backfin"
TRAIN_DIR = DATA_ROOT_DIR / "train_images"
TEST_DIR = DATA_ROOT_DIR / "test_images"
TRAIN_CSV_PATH = DATA_ROOT_DIR / "train.csv"
SAMPLE_SUBMISSION_CSV_PATH = DATA_ROOT_DIR / "sample_submission.csv"
PUBLIC_SUBMISSION_CSV_PATH = INPUT_DIR / "publicsubmission758" / "submission.csv"
IDS_WITHOUT_BACKFIN_PATH = INPUT_DIR / "ids-without-backfin" / "ids_without_backfin.npy"

N_SPLITS = 5

# Generated artifacts. The original names were " encoder_classes.npy" and
# " test.csv" (leading space), which created files whose names start with a
# blank; the stray space is removed here.
ENCODER_CLASSES_PATH = OUTPUT_DIR / "encoder_classes.npy"
TEST_CSV_PATH = OUTPUT_DIR / "test.csv"
TRAIN_CSV_ENCODED_FOLDED_PATH = OUTPUT_DIR / "train_encoded_folded.csv"
CHECKPOINTS_DIR = OUTPUT_DIR / "conv-nextday322pl"
SUBMISSION_CSV_PATH = OUTPUT_DIR / "submission.csv"

# When True, train() shortens the run (2 epochs, 10% of the batches).
DEBUG = False

In [7]:
# """ Path & Settings """
# """ For Dataset 2 """

# INPUT_DIR = Path("..") / "input"
# OUTPUT_DIR = Path("/") / "kaggle" / "working"

# O_DATA_ROOT_DIR = INPUT_DIR / "convert-backfintfrecords" / "happy-whale-and-dolphin-backfin"
# Dataset 2 (YOLOv5 crops). The original wrote `INPUT_DIR / "/kaggle/input/..."`;
# joining an absolute component discards everything to its left, so the
# absolute path is spelled out directly - the resulting value is unchanged.
second_DATA_ROOT_DIR = Path("/kaggle/input/happywhale-yolo5-cropped-dataset")
second_TRAIN_DIR = second_DATA_ROOT_DIR / "train" / "train_images"
# Previously rooted at DATA_ROOT_DIR (dataset 1), which pointed at a directory
# outside this dataset.
# NOTE(review): assumes the test crops mirror the train layout
# ("test/test_images") - confirm against the dataset contents.
second_TEST_DIR = second_DATA_ROOT_DIR / "test" / "test_images"
second_TRAIN_CSV_PATH = second_DATA_ROOT_DIR / "train_df.csv"
# SAMPLE_SUBMISSION_CSV_PATH = O_DATA_ROOT_DIR /"sample_submission.csv"
# PUBLIC_SUBMISSION_CSV_PATH = INPUT_DIR / "publicsubmission758" / "submission.csv"
# IDS_WITHOUT_BACKFIN_PATH = INPUT_DIR / "ids-without-backfin" / "ids_without_backfin.npy"

# N_SPLITS = 5

# ENCODER_CLASSES_PATH = OUTPUT_DIR /" encoder_classes.npy"
# TEST_CSV_PATH = OUTPUT_DIR / " test.csv"
# TRAIN_CSV_ENCODED_FOLDED_PATH = OUTPUT_DIR / "train_encoded_folded.csv"
# CHECKPOINTS_DIR = OUTPUT_DIR / "conv-nextday322pl"
# SUBMISSION_CSV_PATH = OUTPUT_DIR / "submission.csv"

# DEBUG = False

In [8]:
""" Make DataFrame & Cross Validation Function """
def get_image_path(id: str, dir: Path) -> str:
    """Join `dir` and `id` and return the resulting path as a plain string."""
    return str(dir / id)

def stratifiedkfold(df: pd.DataFrame, cfg) -> pd.DataFrame:
    """
    Add a `kfold` column with stratified fold ids (stratified on `individual_id`).

    Args:
        df: frame with an `individual_id` column; it is modified in place.
        cfg: object exposing `n_folds` and `seed`.

    Returns:
        The same `df` object, with `kfold` set to the index of the fold in
        which each row is part of the validation split.
    """
    splitter = StratifiedKFold(
        n_splits=cfg.n_folds,
        shuffle=True,
        random_state=cfg.seed
    )
    df['kfold'] = -1
    kfold_col = df.columns.get_loc('kfold')
    for fold_id, (_, val_positions) in enumerate(splitter.split(df, df.individual_id)):
        # `split` yields positional row indices, so write by position: the
        # label-based `.loc` is only equivalent on a default 0..n-1 index and
        # mis-assigns (or adds) rows on any other index.
        df.iloc[val_positions, kfold_col] = int(fold_id)
    return df

def load_data(data_path: str, train_path) -> pd.DataFrame:
    """
    Read a CSV and add an `image_path` column.

    Each value of the `image` column is combined with `train_path` through
    `get_image_path`.
    """
    frame = pd.read_csv(data_path)
    resolved_paths = frame["image"].apply(get_image_path, dir=train_path)
    return frame.assign(image_path=resolved_paths)

def img_preprocess(df: pd.DataFrame, cfg) -> pd.DataFrame:
    """
    Label-encode `individual_id`, attach fold ids, and persist both artifacts.

    Steps:
        1. Replace `individual_id` with integer codes from a fresh LabelEncoder.
        2. Save the encoder classes to ENCODER_CLASSES_PATH.
        3. Add the `kfold` column through `stratifiedkfold`.
        4. Write the frame to TRAIN_CSV_ENCODED_FOLDED_PATH (without the index).

    NOTE(review): the previous docstring described background removal and
    normalization and linked
    https://www.kaggle.com/code/remekkinas/remove-background-salient-object-detection/notebook
    - nothing of that kind happens in this function.
    """
    label_encoder = LabelEncoder()
    encoded_ids = label_encoder.fit_transform(df["individual_id"])
    df["individual_id"] = encoded_ids
    np.save(ENCODER_CLASSES_PATH, label_encoder.classes_)

    # Author's note: adding the fold information could be pulled out of this
    # function and executed once at the very end.
    folded = stratifiedkfold(df, cfg)
    folded.to_csv(TRAIN_CSV_ENCODED_FOLDED_PATH, index=False)
    return folded

In [9]:
# Load the dataset-1 train CSV and attach the file path of every image
# under TRAIN_DIR; the bare name on the last line displays the frame.
deit_train_df = load_data(TRAIN_CSV_PATH, TRAIN_DIR)
deit_train_df

Unnamed: 0,image,species,individual_id,image_path
0,00021adfb725ed.jpg,melon_headed_whale,cadddb1636b9,../input/convert-backfintfrecords/happy-whale-...
1,000562241d384d.jpg,humpback_whale,1a71fbb72250,../input/convert-backfintfrecords/happy-whale-...
2,0007c33415ce37.jpg,false_killer_whale,60008f293a2b,../input/convert-backfintfrecords/happy-whale-...
3,0007d9bca26a99.jpg,bottlenose_dolphin,4b00fe572063,../input/convert-backfintfrecords/happy-whale-...
4,00087baf5cef7a.jpg,humpback_whale,8e5253662392,../input/convert-backfintfrecords/happy-whale-...
...,...,...,...,...
41569,fff54859cb0beb.jpg,false_killer_whale,b90d49ab0905,../input/convert-backfintfrecords/happy-whale-...
41570,fff603f5af8614.jpg,fin_whale,40fe65946167,../input/convert-backfintfrecords/happy-whale-...
41571,fff8b32daff17e.jpg,cuviers_beaked_whale,1184686361b3,../input/convert-backfintfrecords/happy-whale-...
41572,fff94675cc1aef.jpg,blue_whale,5401612696b9,../input/convert-backfintfrecords/happy-whale-...


In [10]:
# Spot-check the first resolved image path of dataset 1.
deit_train_df.image_path[0]

'../input/convert-backfintfrecords/happy-whale-and-dolphin-backfin/train_images/00021adfb725ed.jpg'

In [11]:
""" load & Preprocess Train Data """
""" 두번째 데이터 세트 image 이름 앞에 2를 붙여주자 """
""" 2 붙이는건 어차피 내가 보기 편하려고 하는거라 중요한건 image_path.. 굳이?? """
# English gloss of the two notes above:
#   - "Let's prefix the image names of the second dataset with 2."
#   - "The prefix would only be for my own readability; what matters is
#      image_path, so is it even necessary?"
# Load the dataset-2 (YOLOv5-cropped) train CSV and attach image paths under
# second_TRAIN_DIR; the bare name on the last line displays the frame.
yolo_train_df = load_data(second_TRAIN_CSV_PATH, second_TRAIN_DIR)
yolo_train_df

Unnamed: 0,image,species,individual_id,image_path
0,00021adfb725ed.jpg,melon_headed_whale,cadddb1636b9,/kaggle/input/happywhale-yolo5-cropped-dataset...
1,000562241d384d.jpg,humpback_whale,1a71fbb72250,/kaggle/input/happywhale-yolo5-cropped-dataset...
2,0007c33415ce37.jpg,false_killer_whale,60008f293a2b,/kaggle/input/happywhale-yolo5-cropped-dataset...
3,0007d9bca26a99.jpg,bottlenose_dolphin,4b00fe572063,/kaggle/input/happywhale-yolo5-cropped-dataset...
4,00087baf5cef7a.jpg,humpback_whale,8e5253662392,/kaggle/input/happywhale-yolo5-cropped-dataset...
...,...,...,...,...
51028,fff639a7a78b3f.jpg,beluga_whale,5ac053677ed1,/kaggle/input/happywhale-yolo5-cropped-dataset...
51029,fff8b32daff17e.jpg,cuviers_beaked_whale,1184686361b3,/kaggle/input/happywhale-yolo5-cropped-dataset...
51030,fff94675cc1aef.jpg,blue_whale,5401612696b9,/kaggle/input/happywhale-yolo5-cropped-dataset...
51031,fffbc5dd642d8c.jpg,beluga_whale,4000b3d7c24e,/kaggle/input/happywhale-yolo5-cropped-dataset...


In [12]:
# Spot-check the first resolved image path of dataset 2.
yolo_train_df.image_path[0]

'/kaggle/input/happywhale-yolo5-cropped-dataset/train/train_images/00021adfb725ed.jpg'

In [13]:
# Stack both crop variants into one frame with a fresh 0..n-1 index, then
# label-encode the ids, add fold numbers and write the CSV artifacts.
# NOTE(review): the two frames share image file names (same source photo,
# different crop) and folds are drawn per row, so two crops of one photo can
# land in different folds - confirm that this train/validation overlap is
# acceptable for the validation loss.
train_df = pd.concat([deit_train_df, yolo_train_df], ignore_index=True)
train_df = img_preprocess(train_df, CFG)
train_df

Unnamed: 0,image,species,individual_id,image_path,kfold
0,00021adfb725ed.jpg,melon_headed_whale,12348,../input/convert-backfintfrecords/happy-whale-...,1
1,000562241d384d.jpg,humpback_whale,1636,../input/convert-backfintfrecords/happy-whale-...,2
2,0007c33415ce37.jpg,false_killer_whale,5842,../input/convert-backfintfrecords/happy-whale-...,2
3,0007d9bca26a99.jpg,bottlenose_dolphin,4551,../input/convert-backfintfrecords/happy-whale-...,4
4,00087baf5cef7a.jpg,humpback_whale,8721,../input/convert-backfintfrecords/happy-whale-...,3
...,...,...,...,...,...
92602,fff639a7a78b3f.jpg,beluga_whale,5520,/kaggle/input/happywhale-yolo5-cropped-dataset...,2
92603,fff8b32daff17e.jpg,cuviers_beaked_whale,1096,/kaggle/input/happywhale-yolo5-cropped-dataset...,0
92604,fff94675cc1aef.jpg,blue_whale,5116,/kaggle/input/happywhale-yolo5-cropped-dataset...,4
92605,fffbc5dd642d8c.jpg,beluga_whale,3909,/kaggle/input/happywhale-yolo5-cropped-dataset...,4


In [14]:
# Use sample submission csv as template
test_df = pd.read_csv(SAMPLE_SUBMISSION_CSV_PATH)
test_df["image_path"] = test_df["image"].apply(get_image_path, dir=TEST_DIR)

test_df.drop(columns=["predictions"], inplace=True)

# Dummy id
test_df["individual_id"] = 0

test_df.to_csv(TEST_CSV_PATH, index=False)

test_df.head()

Unnamed: 0,image,image_path,individual_id
0,000110707af0ba.jpg,../input/convert-backfintfrecords/happy-whale-...,0
1,0006287ec424cb.jpg,../input/convert-backfintfrecords/happy-whale-...,0
2,000809ecb2ccad.jpg,../input/convert-backfintfrecords/happy-whale-...,0
3,00098d1376dab2.jpg,../input/convert-backfintfrecords/happy-whale-...,0
4,000b8d89c738bd.jpg,../input/convert-backfintfrecords/happy-whale-...,0


In [15]:
""" Baseliine Dataset Class """

class HappyWhaleDataset(Dataset):
    """
    Map-style dataset over a frame with `image`, `image_path` and
    `individual_id` columns.

    Each item is a dict with the image name, the RGB image (passed through
    `transform` when one is given) and the target as a long tensor.
    """

    def __init__(self, df: pd.DataFrame, transform = False):
        self.df = df
        self.transform = transform
        # Keep the three columns as arrays for per-item lookup.
        self.image_names = df["image"].values
        self.image_paths = df["image_path"].values
        self.targets = df["individual_id"].values

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int):
        name = self.image_names[index]
        rgb_image = Image.open(self.image_paths[index]).convert("RGB")
        if self.transform:
            rgb_image = self.transform(rgb_image)
        label = torch.tensor(self.targets[index], dtype=torch.long)
        return {"image_name": name, "image": rgb_image, "target": label}

In [16]:
class LitDataModule(pl.LightningDataModule):
    """
    Lightning data module that serves train/val/test `HappyWhaleDataset`s.

    The train frame is read from a CSV that already carries a `kfold` column;
    rows whose `kfold` equals `val_fold` become the validation set and all
    other rows the training set. One transform object, built once in
    `__init__`, is shared by all three datasets.
    """
    def __init__(
        self,
        train_csv_encoded_folded: str,
        test_csv: str,
        val_fold: float,
        image_size: int,
        batch_size: int,
        num_workers: int,
    ):
        """
        Args:
            train_csv_encoded_folded: path of the encoded + folded train CSV.
            test_csv: path of the test CSV.
            val_fold: fold id held out for validation (compared with `kfold`).
            image_size: side length used for the square `input_size`.
            batch_size: batch size for every dataloader.
            num_workers: worker count for every dataloader.
        """
        super().__init__()

        # Makes the constructor arguments available as `self.hparams.<name>`,
        # which the rest of the class relies on.
        self.save_hyperparameters()

        self.train_df = pd.read_csv(train_csv_encoded_folded)
        self.test_df = pd.read_csv(test_csv)
        # NOTE(review): no training flag is passed to timm's `create_transform`,
        # so this is presumably its default (non-augmenting) pipeline - confirm.
        self.transform = create_transform(
            input_size=(self.hparams.image_size, self.hparams.image_size),
            crop_pct=1.0,
        )
    """ timm create_transform: Auto Normalization (ImageNet) """
    def setup(self, stage: Optional[str] = None):
        """Build the datasets for `stage` ("fit", "test", or both when None)."""
        if stage == "fit" or stage is None:
            # Split train df using fold
            train_df = self.train_df[self.train_df.kfold != self.hparams.val_fold].reset_index(drop=True)
            val_df = self.train_df[self.train_df.kfold == self.hparams.val_fold].reset_index(drop=True)

            self.train_dataset = HappyWhaleDataset(train_df, transform=self.transform)
            self.val_dataset = HappyWhaleDataset(val_df, transform=self.transform)

        if stage == "test" or stage is None:
            self.test_dataset = HappyWhaleDataset(self.test_df, transform=self.transform)

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training split; the last partial batch is dropped."""
        return self._dataloader(self.train_dataset, train=True)

    def val_dataloader(self) -> DataLoader:
        """Ordered loader over the validation split."""
        return self._dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Ordered loader over the test frame."""
        return self._dataloader(self.test_dataset)

    def _dataloader(self, dataset: HappyWhaleDataset, train: bool = False) -> DataLoader:
        """Shared DataLoader factory; `train` switches on both shuffling and drop_last."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=train,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            drop_last=train,
        )

In [17]:
""" ArcFace Margin Loss """

class ArcMarginProduct(nn.Module):
    """
    Implement of large margin arc distance:
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature
        m: margin
        easy_margin: if True, apply the margin only where cos(theta) > 0
        ls_eps: label-smoothing factor applied to the one-hot targets (0 disables it)
        cos(theta + m)
    Reference:
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/blob/master/src/modeling/metric_learning.py
    """

    def __init__(self, in_features: int, out_features: int, s: float, m: float, easy_margin: bool,
        ls_eps: float):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        # One weight row per class; fully overwritten by the Xavier init below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold / fallback offset used when theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input: torch.Tensor, label: torch.Tensor, device: Optional[str] = None) -> torch.Tensor:
        """
        Compute margin-adjusted, scaled cosine logits.

        Args:
            input: embeddings, one row per sample.
            label: class index per sample.
            device: device for the one-hot buffer. Defaults to the device of
                the computed cosine matrix, so the module also runs on CPU
                (the previous default of "cuda" failed without a GPU).

        Returns:
            Tensor of shape (batch, out_features): `s * cos(theta + m)` blended
            in at the target positions, `s * cos(theta)` elsewhere.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Enable 16 bit precision
        cosine = cosine.to(torch.float32)

        # Rounding can push |cosine| slightly above 1, which would make
        # 1 - cosine^2 negative and its sqrt NaN; clamp before the sqrt.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0.0, 1.0))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        target_device = cosine.device if device is None else device
        one_hot = torch.zeros(cosine.size(), device=target_device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output
    
""" Focal Loss: Add Compensate for Well-Inference Target with Original CE """

class FocalLoss(nn.Module):
    """
    Focal loss: cross-entropy scaled by (1 - p_t) ** gamma.

    Args:
        gamma: reduces the relative loss for well-classified examples
        alpha: optional class weighting. A number `a` becomes `[a, 1 - a]`, a
            list is used as given. Index 0 is applied to targets equal to 0
            and index 1 to every other target, so only the first two entries
            are ever read.
        size_average: return the mean over samples when True, else the sum.
    Reference:
        https://github.com/clcarwin/focal_loss_pytorch/blob/e11e75bad957aecf641db6998a1016204722c1bb/focalloss.py#L6
    """
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """
        Args:
            input: logits with classes on dim 1; inputs with more than two
                dims are flattened to (N*..., C) first.
            target: integer class indices, one per (flattened) row.

        Returns:
            Scalar loss tensor (mean or sum, depending on `size_average`).
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        # Explicit class dimension (the implicit-dim form is deprecated).
        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        # As in the reference implementation, the modulating factor is treated
        # as a constant: no gradient flows through p_t.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Follow the input's device/dtype instead of the former hard-coded
            # `.cuda()`, which made this branch unusable on CPU.
            self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            select = (target != 0).long().view(-1)
            at = self.alpha.gather(0, select)
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [18]:
""" Pytorch Lightening Module """

class LitModule(pl.LightningModule):
    """
    Lightning module: timm feature backbone -> embedding head -> ArcFace logits.

    Feature maps from the last two backbone stages are batch-normed, passed
    through dropout, average-pooled to 1x1, concatenated on the channel axis
    and projected to `embedding_size`. During train/val steps the embeddings
    are turned into margin logits by `ArcMarginProduct` and scored with
    cross-entropy.
    """
    def __init__(
        self,
        checkpoint_path: str,
        model_name: str,
        pretrained: bool,
        drop_rate: float,
        embedding_size: int,
        num_classes: int,
        arc_s: float,
        arc_m: float,
        arc_easy_margin: bool,
        arc_ls_eps: float,
        optimizer: str,
        learning_rate: float,
        weight_decay: float,
        len_train_dl: int,
        epochs:int,   
    ):
        """
        Args:
            checkpoint_path: only referenced by the commented-out
                `load_state_dict` line below, i.e. currently unused.
            model_name: timm model identifier for the backbone.
            pretrained: NOTE(review) - not read; `timm.create_model` below is
                called with a hard-coded `pretrained=True`.
            drop_rate: forwarded to `timm.create_model`.
            embedding_size: output width of the embedding head.
            num_classes: number of ArcFace classes.
            arc_s, arc_m, arc_easy_margin, arc_ls_eps: forwarded to `ArcMarginProduct`.
            optimizer, learning_rate, weight_decay: read in `configure_optimizers`.
            len_train_dl: steps per epoch handed to OneCycleLR.
            epochs: epoch count handed to OneCycleLR.
        """
        super().__init__()

        # Exposes the constructor arguments as `self.hparams.<name>`.
        self.save_hyperparameters()
        # self.fea_extra_layer = [2, 3]
        self.fea_extra_layer = [-2,-1] # Feature Extract from specific layer: last two layer
        self.model = timm.create_model(
            model_name,
            pretrained=True,
            drop_rate=drop_rate,
            features_only=True,
            out_indices=self.fea_extra_layer # Select Which layer's Embedding Extraction
        )
#        self.model.load_state_dict(torch.load(checkpoint_path),strict=False)
        # 1536 == 512 + 1024: the channel counts of the two pooled feature maps
        # (see `bn2` / `bn`) that `forward` concatenates.
        # NOTE(review): these sizes are hard-coded; presumably they match the
        # convnext_base backbone used in this notebook - other `model_name`
        # values would need different numbers. Confirm.
        in_features = 1536 # Last Hidden Layer's Size before fully connected layer
        self.embedding = nn.Sequential(
            nn.Linear(in_features, embedding_size),
            nn.BatchNorm1d(embedding_size)
        )

        # self.model.reset_classifier(num_classes=0, global_pool="avg")
        # Head for the last selected feature map (features[1] in `forward`).
        self.bn = nn.Sequential(
            nn.BatchNorm2d(1024),
            nn.Dropout(0.2),
            nn.AdaptiveAvgPool2d(1),
        )
        # Head for the second-to-last selected feature map (features[0] in `forward`).
        self.bn2 = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Dropout(0.2),
            nn.AdaptiveAvgPool2d(1),
        )

        self.arc = ArcMarginProduct(
            in_features=embedding_size,
            out_features=num_classes,
            s=arc_s,
            m=arc_m,
            easy_margin=arc_easy_margin,
            ls_eps=arc_ls_eps,
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """ Extract Embedding from Last Two Layer """
        # The backbone returns one feature map per entry of `out_indices`.
        features = self.model(images)
        features[0] = self.bn2(features[0]) 
        features[1] = self.bn(features[1])
        features = torch.cat(features, dim=1) # Concatenate the embeddings coming from the last two layers.
        # flatten(1) removes the 1x1 spatial dims left by the adaptive pooling.
        embeddings = self.embedding(features.flatten(1))

        return embeddings

    def configure_optimizers(self):
        """Build the timm optimizer and a per-step OneCycleLR schedule."""
        optimizer = create_optimizer_v2(
            self.parameters(),
            opt=self.hparams.optimizer,
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        
        # Second positional argument is OneCycleLR's peak learning rate; the
        # schedule length is steps_per_epoch * epochs.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            self.hparams.learning_rate,
            steps_per_epoch=self.hparams.len_train_dl,
            epochs=self.hparams.epochs,
        )
        # "interval": "step" asks Lightning to step the scheduler every
        # optimizer step rather than every epoch.
        scheduler = {"scheduler": scheduler, "interval": "step"}

        return [optimizer], [scheduler]

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Return the training loss (logged as "train_loss")."""
        return self._step(batch, "train")

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Return the validation loss (logged as "val_loss")."""
        return self._step(batch, "val")

    def _step(self, batch: Dict[str, torch.Tensor], step: str) -> torch.Tensor:
        """Shared train/val step: embeddings -> ArcFace logits -> cross-entropy, logged as "<step>_loss"."""
        images, targets = batch["image"], batch["target"]

        embeddings = self(images)
        # The targets are required here because the margin is applied only at
        # the target positions.
        outputs = self.arc(embeddings, targets, self.device)

        loss = self.loss_fn(outputs, targets)
        self.log(f"{step}_loss", loss)

        return loss

In [19]:
""" Pytorch Lightening Trainer """ 

def train(
    checkpoint_path: str,
    train_csv_encoded_folded: str = str(TRAIN_CSV_ENCODED_FOLDED_PATH),
    test_csv: str = str(TEST_CSV_PATH),
    val_fold: float = 0.0,
    image_size: int = 256,
    batch_size: int = 64,
    num_workers: int = 2,
    model_name: str = "tf_efficientnet_b0",
    pretrained: bool = False,
    drop_rate: float = 0.0,
    embedding_size: int = 512,
    num_classes: int = 15587,
    arc_s: float = 64.0,
    arc_m: float = 0.5,
    arc_easy_margin: bool = False,
    arc_ls_eps: float = 0.0,
    optimizer: str = "adam",
    learning_rate: float = 3e-4,
    weight_decay: float = 1e-6,
    checkpoints_dir: str = str(CHECKPOINTS_DIR),
    accumulate_grad_batches: int = 1,
    auto_lr_find: bool = False,
    auto_scale_batch_size: bool = False,
    fast_dev_run: bool = False,
    gpus: int = 1,
    max_epochs: int = 5,
    precision: int = 16,
    stochastic_weight_avg: bool = True,
) -> None:
    """
    Train one fold: build the data module and `LitModule`, then tune and fit.

    Most arguments are forwarded unchanged to `LitDataModule`, `LitModule`,
    `ModelCheckpoint` or `pl.Trainer`. Points worth knowing:

    - The path defaults are evaluated when this function is defined, from the
      module-level path constants.
    - `checkpoint_path` and `pretrained` are forwarded to `LitModule`.
      NOTE(review): as written in this notebook, `LitModule` does not use
      either of them (checkpoint loading is commented out and the backbone is
      created with a hard-coded `pretrained=True`).
    - `stochastic_weight_avg` is accepted but unused; the matching Trainer
      argument is commented out below.
    - The module-level `DEBUG` flag shortens the run to 2 epochs on 10% of the
      train/val batches.

    Returns:
        None. Results are the checkpoint under `checkpoints_dir` and the
        metrics sent to the W&B logger.
    """
    # NOTE(review): fixed seed literal rather than CFG.seed (same value today).
    pl.seed_everything(42)
    wandb_logger = WandbLogger(
        project='convnext_extract_embedding',
        name='HappyWhale_ArcFace',
        group='Convnext_Extract_Embedding'
    )
    datamodule = LitDataModule(
        train_csv_encoded_folded=train_csv_encoded_folded,
        test_csv=test_csv,
        val_fold=val_fold,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    
    # setup() is invoked by hand so the train dataloader exists already and
    # its length can be handed to the OneCycleLR schedule inside LitModule.
    datamodule.setup()
    len_train_dl = len(datamodule.train_dataloader())

    module = LitModule(
        checkpoint_path,
        model_name=model_name,
        pretrained=pretrained,
        drop_rate=drop_rate,
        embedding_size=embedding_size,
        num_classes=num_classes,
        arc_s=arc_s,
        arc_m=arc_m,
        arc_easy_margin=arc_easy_margin,
        arc_ls_eps=arc_ls_eps,
        optimizer=optimizer,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        len_train_dl=len_train_dl,
        epochs=max_epochs,
    )
    # Checkpointing monitors "val_loss"; the file name is fixed per
    # (model_name, image_size).
    model_checkpoint = ModelCheckpoint(
        checkpoints_dir,
        filename=f"{model_name}_{image_size}",
        monitor="val_loss",
    )
    # Stop after 2 validation checks without any decrease of "val_loss".
    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=2, verbose=False, mode="min")

    # NOTE(review): `benchmark=True` and `deterministic=True` are both set;
    # confirm how the installed Lightning version resolves that combination.
    # NOTE(review): in DEBUG mode the Trainer runs 2 epochs on 10% of the
    # batches, while the scheduler above was sized from the full
    # `len_train_dl` and `max_epochs` - confirm this mismatch is acceptable.
    trainer = pl.Trainer(
        logger=wandb_logger,
        accumulate_grad_batches=accumulate_grad_batches,
        auto_lr_find=auto_lr_find,
        auto_scale_batch_size=auto_scale_batch_size,
        benchmark=True,
        callbacks=[model_checkpoint, early_stop_callback],
        deterministic=True,
        fast_dev_run=fast_dev_run,
        gpus=gpus,
        max_epochs=2 if DEBUG else max_epochs,
        precision=precision,
        limit_train_batches=0.1 if DEBUG else 1.0,
        limit_val_batches=0.1 if DEBUG else 1.0,
#         stochastic_weight_avg=stochastic_weight_avg,
    )

    # tune() runs first (relevant for the auto_lr_find / auto_scale_batch_size
    # flags), then the actual fit.
    trainer.tune(module, datamodule=datamodule)
    trainer.fit(module, datamodule=datamodule)

In [20]:
""" Let's Train """

# Backbone and input resolution for this run; every other option uses the
# defaults of train() (e.g. val_fold=0.0, max_epochs=5, precision=16).
model_name = "convnext_base_384_in22ft1k"
image_size = 384
batch_size = 32

# NOTE(review): `checkpoint_path` is handed through to LitModule, whose
# `load_state_dict` line is commented out - so, as written, these weights are
# not loaded. Confirm whether resuming from this checkpoint was intended.
train(
    checkpoint_path='/kaggle/input/pytorch-arcface-train-with-focal-loss/conv-nextday322pl/convnext_base_384_in22ft1k_384.ckpt',
    model_name=model_name,
    image_size=image_size,
    batch_size=batch_size,
)

[34m[1mwandb[0m: Currently logged in as: [33mqcqced[0m ([33mlecr_teams[0m). Use [1m`wandb login --relogin`[0m to force relogin


Downloading: "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth" to /root/.cache/torch/hub/checkpoints/convnext_base_22k_1k_384.pth


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]