In [2]:
import timm
import re, gc
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer
import math
import dataset_class.dataclass as dataset_class
import model.metric as model_metric
import model.metric_learning as metric_learning
import model.model as model_arch
from torch.utils.data import DataLoader
from dataset_class.data_preprocessing import *
from utils.helper import *
from trainer.trainer_utils import *
from model.metric import *
from tqdm.auto import tqdm
import rasterio
from rasterio.enums import Resampling
from torch.utils.data import Dataset
from transformers import AutoProcessor, CLIPImageProcessor
import albumentations as albumentations
from albumentations.pytorch import ToTensorV2

In [3]:
class CFG:
    """Static configuration namespace for this notebook (attributes only, never instantiated here)."""

    # ---- Pipeline setting ----
    train = True
    test = False
    checkpoint_dir = './saved/model'
    resume = False
    load_pretrained = False
    state_dict = '/'
    name = 'FBP3_Base_Train_Pipeline'
    loop = 'SD2Trainer'
    dataset = 'SD2Dataset'  # name of the dataset class (see dataset_class.dataclass)
    model_arch = 'SD2Model'  # name of the model class (see model.model)
    style_model_arch = 'StyleExtractModel'  # name of the style-extractor class
    style_model = 'convnext_base_384_in22ft1k'  # timm model name handed to timm.create_model

    # ---- Common options ----
    wandb = True
    optuna = False  # set True to tune hyper-parameters
    competition = 'FB3'
    seed = 42  # consumed by all_type_seed and by the DataLoader generator
    cfg_name = 'CFG'
    n_gpu = 1
    gpu_id = 0
    # Prefer the first CUDA device, otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    num_workers = 4  # DataLoader worker processes

    # ---- Data options ----
    batch_size = 256

In [4]:
import torch, os, sys, random, json
import numpy as np


def check_device() -> bool:
    """Report whether the MPS (Metal) backend is usable on this machine.

    Returns:
        True when ``torch.backends.mps.is_available()`` reports a usable MPS
        device, False otherwise.
    """
    # `torch.backends.mps.is_available()` is the documented availability probe;
    # `torch.mps.is_available` is not present in older PyTorch releases.
    return bool(torch.backends.mps.is_available())


def check_library(checker: bool) -> "tuple | None":
    """Collect cuDNN status information for the CUDA code path.

    Args:
        checker: device flag used throughout this notebook.
            True  -> current device is MPS.
            False -> current device is CUDA with cuDNN.

    Returns:
        ``(is_available, is_enabled, version)`` read from ``torch.backends.cudnn``
        when ``checker`` is False; ``None`` when ``checker`` is True (nothing is
        queried for the MPS path).
    """
    if checker:
        return None
    cudnn = torch.backends.cudnn
    # The original expression here was a paste typo
    # (`cudnn.enabledtorch.backends.cudnn.enabled`) that raised AttributeError.
    return cudnn.is_available(), cudnn.enabled, cudnn.version()


def class2dict(cfg) -> dict:
    """Snapshot every attribute of ``cfg`` whose name does not start with ``__``.

    Args:
        cfg: any object or class; its attribute names are taken from ``dir(cfg)``.

    Returns:
        Mapping of attribute name to attribute value, in ``dir()`` order.
    """
    return {
        attr: getattr(cfg, attr)
        for attr in dir(cfg)
        if not attr.startswith('__')
    }


def all_type_seed(cfg, checker: bool) -> None:
    """Seed Python, NumPy and PyTorch random number generators from ``cfg.seed``.

    Args:
        cfg: configuration object exposing an integer ``seed`` attribute.
        checker: True selects the MPS branch, False selects the CUDA/cuDNN branch.
    """
    seed = cfg.seed

    # Seeding shared by both device branches.
    os.environ['PYTHONHASHSEED'] = str(seed)  # Python hash seed (environment variable)
    random.seed(seed)                         # stdlib `random`
    np.random.seed(seed)                      # NumPy global RNG
    torch.manual_seed(seed)                   # PyTorch RNG

    if checker:
        # MPS branch.
        torch.mps.manual_seed(seed)
        return

    # CUDA branch: seed the current GPU and all GPUs.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags. NOTE(review): benchmark=True with deterministic=False looks like
    # a speed-over-reproducibility choice; confirm that is intended.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True
    cudnn.enabled = True


def seed_worker(worker_id) -> None:
    """Seed NumPy and stdlib ``random`` from PyTorch's initial seed.

    Passed as ``worker_init_fn`` to the DataLoader in this notebook.

    Args:
        worker_id: worker index supplied by the caller; not used.
    """
    # Reduce torch's initial seed modulo 2**32 before handing it to both RNGs.
    seed_32bit = torch.initial_seed() % (1 << 32)
    for seeder in (np.random.seed, random.seed):
        seeder(seed_32bit)
    
# NOTE(review): with checker=True, check_library does not query cuDNN and returns
# None; the result is unused here.
check_library(True)
# checker=True routes all_type_seed through its MPS branch (torch.mps.manual_seed).
# NOTE(review): CFG.device prefers 'cuda:0' when CUDA is available, so the CUDA/cuDNN
# branch is never taken by this call -- confirm that is intended.
all_type_seed(CFG, True)
# Dedicated generator seeded with CFG.seed; train_loop passes it to SD2Trainer,
# which hands it to the DataLoader.
g = torch.Generator()
g.manual_seed(CFG.seed)

<torch._C.Generator at 0x7f727b225c50>

In [5]:
def load_data(data_path: str) -> pd.DataFrame:
    """Read a CSV file (e.g. train.csv, test.csv, val.csv) into a DataFrame.

    Args:
        data_path: location of the CSV file, forwarded to ``pd.read_csv``.

    Returns:
        The parsed DataFrame, using ``pd.read_csv`` defaults.
    """
    return pd.read_csv(data_path)

In [6]:
class SD2Dataset:
    """Map-style dataset that loads images for the style-extractor model.

    Each item is read with rasterio from the location stored in the FIRST column
    of ``df`` and then resized to 384x384, normalized, and converted to a tensor.

    Args:
        cfg: configuration object; stored on the instance, not read by this class.
        df: DataFrame whose first column holds something ``rasterio.open`` accepts
            (presumably image file paths -- TODO confirm against the CSV).
    """
    def __init__(self, cfg, df: pd.DataFrame) -> None:
        self.cfg = cfg
        self.df = df
        # Resize -> normalize -> tensor. The mean/std values match the standard
        # ImageNet statistics.
        self.img_transform = albumentations.Compose([
            albumentations.Resize(384, 384),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2()]
        )

    def __len__(self) -> int:
        """Number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, item) -> tuple[object, torch.Tensor]:
        """Load and transform one image.

        Args:
            item: positional row index into ``df``.

        Returns:
            image_index: the raw first-column value for this row (identifies the image).
            style_image: transformed image for the style-extractor.
        """
        image_index = self.df.iloc[item, 0]
        # Context manager closes the rasterio dataset handle once the pixels are in
        # memory; the previous code left one open handle per fetched item.
        with rasterio.open(image_index) as image:
            # rasterio returns band-first data; transpose moves the band axis last
            # before handing the array to albumentations.
            tensor_image = image.read(resampling=Resampling.bilinear).transpose(1, 2, 0)
        style_image = self.img_transform(image=tensor_image)['image']  # resize & normalize for style-extractor
        return image_index, style_image

In [7]:
class StyleExtractModel(nn.Module):
    """
    Style-feature extractor built on a timm backbone (forward pass only).

    The backbone is split into four consecutive feature blocks. For each block's
    output a Gram matrix is computed, averaged over its last dimension, and the
    four resulting vectors are concatenated into one style embedding per image.

    Motivation (from the original notes): the CLIP image processor center-crops
    while resizing to 224x224, yet many prompts describe the background, so a
    separate style extractor is used and its features are meant to be
    concatenated with CLIP's image embedding.

    Supported backbones:
        Only models that expose ``stem`` and ``stages`` attributes (e.g. the
        ConvNeXt family in timm) can be split by this class. Other families
        such as efficientnet (``blocks``) or resnet (``layer1``..``layer4``)
        are NOT handled and raise AttributeError at construction time.

    [Reference]
    https://www.kaggle.com/code/tanreinama/style-extract-from-vgg-clip-object-extract
    """
    def __init__(self, cfg) -> None:
        """
        Args:
            cfg: configuration object; ``cfg.style_model`` is the timm model name.

        Raises:
            AttributeError: if the created backbone has no ``stem``/``stages``.
        """
        super().__init__()
        self.cfg = cfg
        self.style_model = timm.create_model(
            self.cfg.style_model,
            pretrained=True,
            features_only=False,  # build the full model; forward() only uses stem/stages
        )
        self.avg = nn.AdaptiveAvgPool1d(1)

        # Fail with an explicit message instead of an opaque attribute lookup error
        # when the chosen backbone cannot be split into stem + stages.
        missing = [attr for attr in ('stem', 'stages') if not hasattr(self.style_model, attr)]
        if missing:
            raise AttributeError(
                f"style_model '{self.cfg.style_model}' has no attribute(s) {missing}; "
                "StyleExtractModel only supports backbones exposing 'stem' and 'stages'"
            )

        # feature1 = stem followed by the first stage; feature2..4 = remaining stages.
        self.feature1 = self.style_model.stem + self.style_model.stages[0:1]
        self.feature2 = self.style_model.stages[1:2]
        self.feature3 = self.style_model.stages[2:3]
        self.feature4 = self.style_model.stages[3:4]

    @staticmethod
    def gram_matrix(x: torch.Tensor) -> torch.Tensor:
        """
        Gram matrix of a 4-D feature map, normalized by the spatial size.

        Args:
            x: tensor of shape (b, c, h, w).

        Returns:
            Tensor of shape (b, c, c) equal to ``F @ F^T / (h * w)`` where ``F`` is
            ``x`` flattened to (b, c, h * w).
        """
        b, c, h, w = x.size()
        # reshape (not view) so non-contiguous feature maps are accepted as well
        f = x.reshape(b, c, h * w)
        g = torch.bmm(f, f.transpose(1, 2)) / (h * w)
        return g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: image batch fed to the backbone stem.

        Returns:
            2-D tensor: the four per-block style vectors concatenated along dim 1.
        """
        embedding1 = self.feature1(x)
        embedding2 = self.feature2(embedding1)
        embedding3 = self.feature3(embedding2)
        embedding4 = self.feature4(embedding3)

        g1 = self.gram_matrix(embedding1)
        g2 = self.gram_matrix(embedding2)
        g3 = self.gram_matrix(embedding3)
        g4 = self.gram_matrix(embedding4)
        # Average each (b, c, c) Gram matrix over its last dim -> (b, c, 1) -> (b, c).
        g = [self.avg(g1).squeeze(2), self.avg(g2).squeeze(2), self.avg(g3).squeeze(2), self.avg(g4).squeeze(2)]
        return torch.cat(g, dim=1)

In [8]:
class SD2Trainer:
    """
    Driver that runs the style-extractor over a CSV-listed image set and collects
    the resulting embeddings. Despite the name, no training happens here: the
    model is put in eval mode and executed under ``torch.no_grad()``.

    Args:
        cfg: configuration object; this class reads ``batch_size``, ``num_workers``
            and ``device`` from it and forwards it to the dataset and model.
        generator: ``torch.Generator`` handed to the DataLoader.
        data_path: CSV file listing the images. Defaults to the path that was
            previously hard-coded, so existing two-argument calls are unaffected.
    """
    def __init__(
        self,
        cfg,
        generator,
        data_path: str = './dataset_class/final_downsample_prompt.csv',
    ) -> None:
        self.cfg = cfg
        self.generator = generator
        self.df = pd.read_csv(data_path)

    def make_batch(self):
        """ Build the DataLoader over the full DataFrame """
        # Custom Datasets
        train_dataset = SD2Dataset(self.cfg, self.df)

        # DataLoader
        # NOTE: shuffle=True means batches arrive in random order; outputs stay
        # aligned because train_fn stores each batch's image identifiers next to
        # its embeddings.
        loader_train = DataLoader(
            train_dataset,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            worker_init_fn=seed_worker,
            generator=self.generator,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            drop_last=False,
        )
        return loader_train

    def model_setting(self):
        """ Instantiate the style-extractor and move it to ``cfg.device`` """
        style_model = StyleExtractModel(self.cfg)
        style_model.to(self.cfg.device)
        return style_model

    # Embedding extraction (name kept for compatibility with existing callers)
    def train_fn(self, loader_train, style_model):
        """
        Run ``style_model`` over every batch and collect the outputs.

        Args:
            loader_train: iterable yielding ``(image_index, style_image)`` batches.
            style_model: module mapping an image batch to style features.

        Returns:
            image_list: one entry per batch, holding that batch's ``image_index``
                exactly as yielded by the loader.
            embedding_list: one NumPy array per batch with the model output.
        """
        image_list, embedding_list = [], []
        # Autograd anomaly detection is intentionally NOT enabled: nothing is
        # back-propagated in this loop, and the flag is global, so switching it on
        # here would leave it enabled for all later code in the process.
        style_model.eval()
        for step, (image_index, style_image) in enumerate(tqdm(loader_train)):
            style_image = style_image.to(self.cfg.device)  # style image to the target device
            with torch.no_grad():
                style_features = style_model(style_image)  # style image to style feature

            style_features = style_features.detach().cpu().numpy()
            embedding_list.append(style_features)
            image_list.append(image_index)

        return image_list, embedding_list

In [9]:
def train_loop(cfg: any) -> None:
    """Extract style embeddings for the whole dataset and save them to disk.

    Builds an ``SD2Trainer`` from ``cfg`` and the notebook-level generator ``g``,
    runs its extraction pass, then writes two files into the working directory:
    ``style_image_name.pth`` (per-batch image identifiers) and
    ``style_embedding_name.pth`` (per-batch embeddings).
    """
    trainer = SD2Trainer(cfg, g)
    # Arguments are evaluated left to right: the DataLoader is built first, then the model.
    image_list, embedding_list = trainer.train_fn(
        trainer.make_batch(),
        trainer.model_setting(),
    )
    for payload, target_file in (
        (image_list, 'style_image_name.pth'),
        (embedding_list, 'style_embedding_name.pth'),
    ):
        torch.save(payload, target_file)

In [10]:
# Run the extraction end to end with the notebook configuration; train_loop saves
# its results to style_image_name.pth and style_embedding_name.pth.
train_loop(CFG)

  0%|          | 0/853 [00:00<?, ?it/s]

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
