In [11]:
import gc
import torch
import numpy as np
import pandas as pd
import rasterio
from rasterio.enums import Resampling
from torch.utils.data import Dataset
from torch import Tensor
from transformers import AutoProcessor, CLIPImageProcessor
from tqdm.auto import tqdm

In [8]:
class CFG:
    """Static configuration namespace for the notebook.

    Every setting is a plain class attribute, so values are read as
    ``CFG.<name>`` without instantiating the class.
    """

    # ---- Pipeline settings ----
    train = True
    test = False
    checkpoint_dir = './saved/model'
    resume = False
    load_pretrained = False
    state_dict = '/'
    name = 'FBP3_Base_Train_Pipeline'
    loop = 'SD2Trainer'
    # Original author's note: dataset_class.dataclass.py -> FBPDataset, MPLDataset
    dataset = 'SD2Dataset'
    # Original author's note: model.model.py -> FBPModel, MPLModel
    model_arch = 'SD2Model'
    # Original author's note: model.model.py -> StyleModel
    style_model_arch = 'StyleExtractModel'
    style_model = 'convnext_base_384_in22ft1k'
    model = 'openai/clip-vit-large-patch14'

    # ---- Common options ----
    wandb = True
    optuna = False  # set True to tune hyperparameters (per the original note)
    competition = 'FB3'
    seed = 42
    cfg_name = 'CFG'
    n_gpu = 1
    # First CUDA device when one is available, otherwise the CPU.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    gpu_id = 0
    num_workers = 4

    # ---- Data options ----
    batch_size = 256

In [9]:
# Read the prompt metadata CSV into `df`; the bare `df` on the last line
# makes the frame this cell's displayed output.
df = pd.read_csv(filepath_or_buffer='./dataset_class/final_final_prompt.csv')
df

Unnamed: 0,image_name,prompt,fold,style_index
0,./dataset_class/data_folder/SDDB2_2M_75/SDDB2-...,"only memories remain, trending on artstation",1,149562
1,./dataset_class/data_folder/SDDB2_2M_75/SDDB2-...,a painting by edward hopper of a group of men ...,2,209980
2,./dataset_class/data_folder/SDDB2_2M_75/SDDB2-...,a painting by edward hopper of a glowing human...,3,160493
3,./dataset_class/data_folder/SDDB2_2M_75/SDDB2-...,a painting by edward hopper of the angel of de...,2,194815
4,./dataset_class/data_folder/SDDB2_2M_75/SDDB2-...,a painting by edward hopper of a group of four...,0,193190
...,...,...,...,...
218197,./dataset_class/data_folder/ChatGPT_SDDB2/chat...,A vintage photograph of a jazz band made up of...,4,177672
218198,./dataset_class/data_folder/ChatGPT_SDDB2/chat...,"A vibrant collage depicting a cyborg giraffe, ...",2,171345
218199,./dataset_class/data_folder/ChatGPT_SDDB2/chat...,A minimalist ink drawing of a yeti and a polar...,2,195483
218200,./dataset_class/data_folder/ChatGPT_SDDB2/chat...,A surrealist portrayal of a dreamy sky filled ...,1,82530


In [None]:
# Pre-compute CLIP pixel values for every row of `df`.
#
# Results are accumulated in two parallel lists:
#   image_list       -- the value of the first column of each row (used as the image path)
#   pixel_value_list -- the processor's 'pixel_values' entry for that image
# NOTE(review): every processed image is kept in memory until the loop ends;
# with the full table this may be very large -- confirm it fits in RAM.
image_processor = CLIPImageProcessor.from_pretrained(CFG.model)
image_list, pixel_value_list = [], []
for i in tqdm(range(len(df))):
    image_index = df.iloc[i, 0]
    # Open through a context manager so each dataset handle is closed as soon
    # as its pixels are read, instead of leaking one open file per image.
    with rasterio.open(image_index) as image:
        # read() returns band-first data; transpose to put the band axis last.
        # NOTE(review): no out_shape is passed, so the resampling argument
        # presumably has no effect here -- confirm against the rasterio docs.
        tensor_image = image.read(resampling=Resampling.bilinear).transpose(1, 2, 0)
    clip_image = image_processor(tensor_image)['pixel_values']

    image_list.append(image_index)
    pixel_value_list.append(clip_image)

  0%|          | 0/218202 [00:00<?, ?it/s]

In [None]:
# Persist both accumulated lists to disk so they can be reloaded later
# without re-running the preprocessing loop.
torch.save(obj=image_list, f='clip_image_list.pth')
torch.save(obj=pixel_value_list, f='clip_pixel.pth')

In [None]:
# Re-read the prompt table from disk and append a `clip_index` column
# filled with the placeholder value -1 on every row.
df = pd.read_csv('./dataset_class/final_final_prompt.csv').assign(clip_index=-1)

In [None]:
# Reload the cached image-name list and pixel values from disk.
# The name list lives in 'clip_image_list.pth' -- the file name the notebook's
# torch.save cell writes; the previous 'lip_image_list.pth' was missing the
# leading 'c' and pointed at a file this notebook never creates.
clip_name_list = torch.load('clip_image_list.pth')
clip_pixel_list = torch.load('clip_pixel.pth')

In [None]:
# Flatten the two-level (outer entry -> inner item) name/pixel containers
# into three parallel flat lists:
#   new_name   -- every inner item of clip_name_list, in order
#   real_pixel -- the item at the same [i][j] position in clip_pixel_list
#   pixel_idx  -- a running 0-based position for each appended item
new_name, pixel_idx, real_pixel = [], [], []
# Not populated in the lines visible here.
embedding_dict = {}
# Running counter; equals the number of items appended so far.
count = 0
for i in range(len(clip_name_list)):
    # NOTE(review): this indexes clip_name_list[i][j], so each outer entry is
    # assumed to be a sequence of names (e.g. a batch). If the entries are
    # plain path strings, [j] yields single characters -- confirm the layout
    # of the saved .pth files.
    for j in range(len(clip_name_list[i])):
        new_name.append(clip_name_list[i][j])
        # Assumes clip_pixel_list[i] has at least as many items as
        # clip_name_list[i] -- TODO confirm.
        real_pixel.append(clip_pixel_list[i][j])
        pixel_idx.append(count)
        count += 1