In [None]:
import os
import random 
import pandas as pd
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm

import timm
from timm.utils import AverageMeter

from sklearn.preprocessing import normalize

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms

from transformers import AutoProcessor, AutoModel, AutoConfig

import sys
sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer

import warnings
warnings.filterwarnings('ignore')


In [None]:
class CFG:
    """Static inference settings: model artifact locations and batching options."""

    # Data-loading settings.
    batch_size = 64
    input_size = 224  # target size (pixels) handed to the resize transform

    # Model identity and on-disk artifacts.
    model_name = 'openai/clip-vit-large-patch14'
    model_config = '/kaggle/input/my-clip-model/clip-vit-large-patch14-openai/config.json'
    model_path = '/kaggle/input/my-dataset-of-conv/clip-vit-large-patch14_epoch2-3.pth'

In [None]:
class DiffusionTestDataset(Dataset):
    """Inference-only dataset that yields transformed images (no labels).

    Args:
        images: Sequence of image file paths (anything ``Image.open`` accepts).
        transform: Callable applied to each decoded PIL image; its return
            value is what ``__getitem__`` yields.
    """

    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it transformed."""
        # The context manager releases the file handle promptly.
        # convert('RGB') forces three channels: PNG files may be RGBA, palette
        # or grayscale, which would otherwise yield a tensor with the wrong
        # channel count for a 3-channel normalisation step.
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        return self.transform(image)

In [None]:
class Net(nn.Module):
    """Vision branch of a CLIP-style model followed by a 1024 -> 384 linear head.

    Args:
        model_config: Path (or identifier) of a local model config; it is
            resolved with ``local_files_only=True`` and the model is built
            from that config via ``AutoModel.from_config``.
    """

    def __init__(self, model_config):
        super().__init__()
        clip_config = AutoConfig.from_pretrained(model_config, local_files_only=True)
        full_model = AutoModel.from_config(clip_config)
        # Only the image branch of the model is kept.
        self.vision = full_model.vision_model
        # 1024 has to match the width of the vision model's pooled output.
        self.fc = nn.Linear(in_features=1024, out_features=384)

    def forward(self, x):
        """Return the linear projection of the vision model's pooled output."""
        vision_outputs = self.vision(x)
        pooled = vision_outputs['pooler_output']
        projected = self.fc(pooled)
        return projected

In [None]:
def predict(
    images,
    model_path,
    model_config,
    model_name,
    input_size,
    batch_size,
    num_tta=2
):
    """Embed images with the fine-tuned vision model, averaged over TTA passes.

    Each pass runs every image through resize/crop, a random horizontal flip,
    tensor conversion and normalisation, then through ``Net``. Per-image
    outputs are scaled by their max absolute value, L2-normalised, and the
    passes are averaged.

    NOTE(review): the flip is random and no seed is set in this function, so
    repeated calls can return slightly different values.

    Args:
        images: Sequence of image file paths.
        model_path: Path of the checkpoint holding the ``Net`` state dict.
        model_config: Config path forwarded to ``Net``.
        model_name: Unused; kept so existing call sites keep working.
        input_size: Target side length in pixels for resize and centre crop.
        batch_size: Number of images per forward pass.
        num_tta: Number of test-time-augmentation passes to average (default 2).

    Returns:
        1-D ``numpy`` array: the (n_images x embedding_dim) matrix flattened
        row by row, in the same order as ``images``. Empty when ``images`` is
        empty.

    Raises:
        ValueError: If ``num_tta`` is smaller than 1.
    """
    if num_tta < 1:
        raise ValueError(f'num_tta must be >= 1, got {num_tta}')

    # Nothing to embed: bail out before loading the model. Without this guard
    # np.vstack([]) below would raise on an empty image list.
    if len(images) == 0:
        return np.zeros(0, dtype=np.float32)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize(input_size),
        # Resize(int) only fixes the shorter side, so non-square inputs would
        # give tensors of varying width that cannot be batched. The crop is a
        # no-op for square images.
        transforms.CenterCrop(input_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = DiffusionTestDataset(images, transform)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,  # keep output order aligned with `images`
        batch_size=batch_size,
        pin_memory=device.type == 'cuda',  # pinning only helps host->GPU copies
        num_workers=2,
        drop_last=False
    )

    model = Net(model_config)
    # map_location lets a checkpoint saved from GPU tensors load on a CPU-only
    # machine. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tta_sum = None
    with torch.no_grad():
        for _ in range(num_tta):
            batch_embeddings = []
            for X in tqdm(dataloader, leave=False):
                X_out = model(X.to(device)).cpu().numpy()
                # Scale by the per-row max magnitude before L2 normalisation.
                X_out = X_out / (np.abs(X_out).max(axis=-1, keepdims=True) + 0.0000001)
                batch_embeddings.append(normalize(X_out))

            pass_preds = np.vstack(batch_embeddings).flatten()
            tta_sum = pass_preds if tta_sum is None else tta_sum + pass_preds

    return tta_sum / num_tta

In [None]:
# Sort the glob results: Path.glob yields entries in filesystem order, which is
# not guaranteed to be stable between runs, and the flat embedding vector
# returned by predict() is only meaningful relative to the order of `images`.
images = sorted(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
embeddings = predict(
    images=images,
    model_path=CFG.model_path,
    model_config=CFG.model_config,
    model_name=CFG.model_name,
    input_size=CFG.input_size,
    batch_size=CFG.batch_size,
)