In [None]:
import os
import gc
import time
import glob

import numpy as np
import pandas as pd
from PIL import Image, ImageOps
from pathlib import Path
from tqdm.notebook import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms

from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity

import timm
from sentence_transformers import SentenceTransformer

gc.collect()

In [None]:
class CFG_CONV:
    """Settings handed to `predict_conv` for the ConvNeXt image encoder."""

    # timm architecture identifier used to build the network.
    model_name = 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384'
    # Checkpoint file whose state dict is loaded into that network.
    model_path = '/kaggle/input/my-large-vit-path/conv_next_epoch2-3.pth'
    # Size given to the resize transform, and number of images per batch.
    input_size, batch_size = 384, 64

In [None]:
class DiffusionTestDatasetConv(Dataset):
    """Map-style dataset that decodes image files and applies a transform.

    Args:
        images: Sequence of image file paths (anything `PIL.Image.open`
            accepts); its length defines the dataset length.
        transform: Callable applied to each decoded RGB image. Whatever it
            returns is what `__getitem__` yields.
    """

    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image `idx`, force it to 3-channel RGB and transform it."""
        # PNG files can be RGBA, palette or grayscale. Without an explicit
        # conversion the resulting tensor has 1 or 4 channels and a 3-channel
        # Normalize (as used in predict_conv) raises. `convert` also forces
        # the pixel data to be read, so the file handle can be closed here
        # instead of lingering inside DataLoader workers.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert('RGB')
        return self.transform(image)

In [None]:
def predict_conv(
    images,
    model_path,
    model_name,
    input_size,
    batch_size,
    embedding_dim=384,
    tta_rounds=2
):
    """Embed images with a timm model and average over test-time augmentation.

    Args:
        images: Sequence of image file paths.
        model_path: Path to a state-dict checkpoint for the model.
        model_name: timm architecture name passed to `timm.create_model`.
        input_size: Size given to `transforms.Resize`.
        batch_size: Number of images per inference batch.
        embedding_dim: Size of the model's output head (`num_classes`).
            Defaults to 384, the value previously hard-coded.
        tta_rounds: Number of passes over the data that are averaged.
            Defaults to 2, the value previously hard-coded.

    Returns:
        A 1-D numpy array with the L2-normalised per-image outputs of every
        pass averaged and flattened in image order. An empty float32 array
        is returned when `images` is empty.

    Raises:
        ValueError: If `tta_rounds` is smaller than 1.
    """
    if tta_rounds < 1:
        raise ValueError(f'tta_rounds must be >= 1, got {tta_rounds}')
    # Nothing to embed: skip building the model and avoid np.vstack([]),
    # which raises on an empty list.
    if len(images) == 0:
        return np.empty(0, dtype=np.float32)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        # NOTE(review): an int size rescales the shorter side only, so inputs
        # with differing aspect ratios will not collate into one batch.
        transforms.Resize(input_size),
        # NOTE(review): the flip is random per image and per pass, so results
        # are not reproducible unless the RNG is seeded by the caller.
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = DiffusionTestDatasetConv(images, transform)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=False
    )

    model = timm.create_model(model_name, pretrained=False, num_classes=embedding_dim)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine,
    # which the device fallback above is meant to support.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tta_sum = None
    with torch.no_grad():
        for _ in range(tta_rounds):
            batch_preds = []
            for X in tqdm(dataloader, leave=False):
                X_out = model(X.to(device)).cpu().numpy()
                # Scale each row by its max magnitude before the L2
                # normalisation; the epsilon guards all-zero rows.
                X_out = X_out / (np.abs(X_out).max(axis=-1, keepdims=True) + 0.0000001)
                batch_preds.append(normalize(X_out))

            round_preds = np.vstack(batch_preds).flatten()
            tta_sum = round_preds if tta_sum is None else tta_sum + round_preds

    return tta_sum / tta_rounds

In [None]:
# Gather every PNG in the input image folder, then embed them with the
# model described by CFG_CONV.
images = [
    png_path
    for png_path in Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png')
]
embeddings = predict_conv(
    images=images,
    model_path=CFG_CONV.model_path,
    model_name=CFG_CONV.model_name,
    input_size=CFG_CONV.input_size,
    batch_size=CFG_CONV.batch_size,
)