In [None]:
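'''
Notebook that extracts image and text embeddings for both the train and test
datasets with the PriceModel backbone and saves them as CSV files
(train_clip_embed.csv and test_clip_embed.csv).
'''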
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import torch
from torch.utils.data import DataLoader, random_split
from lightgbm import LGBMRegressor
from model_data import PriceModel, PriceDataset, collate_function, SMAPELoss, LogCoshLoss
from tqdm import tqdm
import gc

In [3]:
# Number of samples per batch; reused when the test loader is built further down.
batch_size = 128

# Wrap the SigLIP2 checkpoint in the project's PriceModel, then load the
# checkpoint saved under output/ (presumably the fine-tuned weights -- confirm).
clip_model = PriceModel(
    checkpoint='google/siglip2-large-patch16-256',
    cache_dir='hf_models',
)
clip_model.load_checkpoint('output/best.pth')

# Training split: annotations CSV plus the directory holding its images.
dataset = PriceDataset(
    annotations_file='student_resource/dataset/train_focused.csv',
    image_dir='train_images',
    content_col='focused_sentence',
)

# shuffle=False keeps batches in dataset order, so the extracted rows can
# presumably be matched back to the annotations file by position.
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_function,
)

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
# Optional smoke test (disabled): run the pipeline on the first 500 samples only.
# check_dataset = torch.utils.data.Subset(dataset, list(range(500)))
# check_loader = DataLoader(dataset=check_dataset, batch_size=batch_size, shuffle=False, collate_fn = collate_function)

In [5]:
# Release unused memory held by PyTorch's CUDA caching allocator before the
# long embedding-extraction pass.
torch.cuda.empty_cache()

In [32]:
def get_clip_data(clip_model, loader, collect_content=False):
    """Run the model backbone over ``loader`` and gather embeddings and targets.

    For every batch, the images and texts are passed through
    ``clip_model.processor``; the backbone's image and text features are then
    computed and concatenated along the last dimension (image features first,
    text features second).

    Args:
        clip_model: Object exposing ``eval()``, a callable ``processor`` and a
            ``backbone`` with ``device``, ``get_image_features`` and
            ``get_text_features``.
        loader: Sized iterable yielding ``(images, texts, targets)`` batches.
            ``targets`` must be accepted by ``torch.cat``.
        collect_content: When True, the texts of every batch are appended to
            the returned ``content`` list. Defaults to False, which leaves the
            list empty.

    Returns:
        A tuple ``(embeddings, prices, content)``:
            embeddings: CPU tensor with all batch embeddings concatenated
                along dim 0, or None when ``loader`` yields no batches.
            prices: All batch targets concatenated along dim 0, or None when
                ``loader`` yields no batches.
            content: List of collected texts (empty unless
                ``collect_content`` is True).
    """
    # Accumulate per-batch chunks and concatenate once at the end. Growing a
    # tensor with torch.cat on every iteration re-copies all previous rows.
    embedding_chunks = []
    price_chunks = []
    content = []

    clip_model.eval()
    with torch.inference_mode():
        for images, texts, targets in tqdm(loader, total=len(loader)):
            inputs = clip_model.processor(
                images=images,
                text=texts,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
            ).to(clip_model.backbone.device)
            image_features = clip_model.backbone.get_image_features(pixel_values=inputs['pixel_values'])
            text_features = clip_model.backbone.get_text_features(input_ids=inputs['input_ids'])

            # Move each batch to CPU immediately so results do not pile up on the device.
            embedding_chunks.append(torch.cat([image_features, text_features], dim=-1).cpu())
            price_chunks.append(targets)

            if collect_content:
                content.extend(texts)

        # An empty loader yields (None, None, []), matching the previous behavior.
        embeddings = torch.cat(embedding_chunks, dim=0) if embedding_chunks else None
        prices = torch.cat(price_chunks, dim=0) if price_chunks else None

    return embeddings, prices, content

    
    
                
    

In [8]:
# Extract embeddings and targets for the training split.
# Long-running: the recorded run took about 1h13m for 575 batches.
embeddings, prices, content = get_clip_data(clip_model, loader)

100%|██████████████████████████████████████████████████████████████████████████████| 575/575 [1:13:00<00:00,  7.62s/it]


In [9]:
# Sanity check: one embedding row and one price per training sample.
# NOTE(review): the recorded len(content) of 73499 does not match the
# get_clip_data definition in this notebook (executed later, as In [32]),
# which returns an empty content list for this call -- the output is stale.
embeddings.shape, prices.shape, len(content)

(torch.Size([73499, 2048]), torch.Size([73499]), 73499)

In [25]:
# Wrap the embedding tensor in a DataFrame: one row per sample, with integer
# column labels 0..2047 (one per embedding dimension, as shown in the output).
train_data = pd.DataFrame(embeddings)
train_data.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,0.17402,-0.143133,-0.242676,-0.533552,0.782315,1.973485,0.928428,0.253139,0.6613,-0.702788,...,-0.461073,-0.125172,0.848946,-1.164739,0.02556,0.019596,0.232899,-0.708067,1.323833,0.778812
1,-1.044623,-0.257172,-0.450125,-1.454822,1.349169,2.260221,0.853942,1.736883,1.243835,-1.209577,...,-0.774796,-0.531549,0.408175,-1.389376,-0.290974,0.27216,0.05591,-0.909074,1.379296,0.838086
2,0.57262,0.231423,-0.66519,-0.891878,0.52189,0.779511,1.5059,0.247009,-0.148506,-1.666485,...,-0.396459,-0.253824,1.169248,-1.005889,-0.05043,0.145862,0.427845,-0.818516,1.431625,0.82927
3,-1.323858,1.099959,-1.524922,-0.707894,1.355102,2.585816,2.315762,3.067365,1.370459,-0.887541,...,-0.711265,-0.392438,0.868274,-1.389667,-0.025847,0.299311,0.401773,-1.306074,1.137081,1.183864
4,-1.899729,0.210678,-0.451701,-0.719948,1.165255,2.497567,2.356846,1.849354,1.890857,-1.406322,...,-0.277977,-0.530559,-0.226734,-0.72477,-1.245805,-0.022336,-0.252454,-1.618746,1.719296,1.287199


In [26]:
# Append the targets as a final 'prices' column (mutates train_data in place).
train_data['prices'] = prices
train_data.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2039,2040,2041,2042,2043,2044,2045,2046,2047,prices
0,0.17402,-0.143133,-0.242676,-0.533552,0.782315,1.973485,0.928428,0.253139,0.6613,-0.702788,...,-0.125172,0.848946,-1.164739,0.02556,0.019596,0.232899,-0.708067,1.323833,0.778812,1.773256
1,-1.044623,-0.257172,-0.450125,-1.454822,1.349169,2.260221,0.853942,1.736883,1.243835,-1.209577,...,-0.531549,0.408175,-1.389376,-0.290974,0.27216,0.05591,-0.909074,1.379296,0.838086,2.647592
2,0.57262,0.231423,-0.66519,-0.891878,0.52189,0.779511,1.5059,0.247009,-0.148506,-1.666485,...,-0.253824,1.169248,-1.005889,-0.05043,0.145862,0.427845,-0.818516,1.431625,0.82927,1.088562
3,-1.323858,1.099959,-1.524922,-0.707894,1.355102,2.585816,2.315762,3.067365,1.370459,-0.887541,...,-0.392438,0.868274,-1.389667,-0.025847,0.299311,0.401773,-1.306074,1.137081,1.183864,3.444895
4,-1.899729,0.210678,-0.451701,-0.719948,1.165255,2.497567,2.356846,1.849354,1.890857,-1.406322,...,-0.530559,-0.226734,-0.72477,-1.245805,-0.022336,-0.252454,-1.618746,1.719296,1.287199,4.211979


In [27]:
# Persist the training embeddings and prices; index=False omits the row index.
train_data.to_csv('train_clip_embed.csv', index=False)

In [29]:
# Drop the training frame and force a garbage-collection pass to reclaim
# memory before the test split is processed.
del train_data
gc.collect()

764

In [30]:
# Test split. This rebinds `dataset` and `loader`, replacing the training ones.
# return_target=False: presumably the test annotations carry no price column,
# so the dataset is told not to return one -- confirm in model_data.
dataset = PriceDataset(
    annotations_file='student_resource/dataset/test_focused.csv',
    image_dir='test_images',
    content_col='focused_sentence',
    return_target=False,
)

# Same loader settings as for the training split; order is kept (shuffle=False).
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_function,
)

In [31]:
# Release unused memory held by PyTorch's CUDA caching allocator before the
# test-split extraction pass.
torch.cuda.empty_cache()

In [33]:
# Extract embeddings for the test split. This rebinds embeddings, prices and
# content, discarding the training results held under the same names.
# Long-running: the recorded run took about 1h22m for 586 batches.
embeddings, prices, content = get_clip_data(clip_model, loader)

100%|██████████████████████████████████████████████████████████████████████████████| 586/586 [1:21:54<00:00,  8.39s/it]


In [34]:
# Sanity check: one embedding row and one price entry per test sample.
# The recorded output shows an empty content list for this run.
embeddings.shape, prices.shape, len(content)

(torch.Size([75000, 2048]), torch.Size([75000]), 0)

In [35]:
# Wrap the test embedding tensor in a DataFrame: one row per sample, with
# integer column labels 0..2047 (as shown in the output).
test_data = pd.DataFrame(embeddings)
test_data.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,-1.34876,0.32193,-1.185519,-1.311213,2.139934,1.840348,2.206761,1.284696,0.920878,-0.815361,...,-0.535424,-0.162835,0.537928,-1.464591,-0.306322,1.110158,-0.007308,-1.558966,1.406268,0.887152
1,-1.027735,0.954196,-1.551497,-1.714152,1.232652,1.91899,2.447417,1.904164,0.606909,-1.267657,...,-0.374746,-0.045907,0.837247,-1.404851,0.123945,1.15284,0.325666,-0.736876,0.973377,0.994311
2,-0.955755,-0.640269,-1.076678,0.978837,0.541126,1.586141,2.662478,2.379204,0.993017,-1.366224,...,-0.753336,-0.414731,0.505255,-1.105734,-0.891332,0.667471,0.047161,-1.489766,1.431225,0.953298
3,0.151534,0.322207,-1.154273,-0.352811,0.178547,0.760418,1.133446,-0.24393,0.610778,-0.542478,...,-0.585725,-0.376377,0.916929,-1.099116,-0.982928,0.318741,0.430117,-1.1909,1.347082,1.114862
4,-0.934557,0.206794,-1.458136,-1.488192,0.629494,1.955989,2.185993,0.941259,1.750161,-1.073258,...,-0.718887,-0.303363,0.933871,-1.38036,0.081186,0.506747,0.21265,-1.234667,1.046096,1.404367


In [37]:
# Add a 'prices' column so the test CSV has the same columns as the train CSV.
# NOTE(review): the recorded values are all 0.0 -- presumably placeholders
# because the test dataset was built with return_target=False. They are not
# real prices.
test_data['prices'] = prices
test_data.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2039,2040,2041,2042,2043,2044,2045,2046,2047,prices
0,-1.34876,0.32193,-1.185519,-1.311213,2.139934,1.840348,2.206761,1.284696,0.920878,-0.815361,...,-0.162835,0.537928,-1.464591,-0.306322,1.110158,-0.007308,-1.558966,1.406268,0.887152,0.0
1,-1.027735,0.954196,-1.551497,-1.714152,1.232652,1.91899,2.447417,1.904164,0.606909,-1.267657,...,-0.045907,0.837247,-1.404851,0.123945,1.15284,0.325666,-0.736876,0.973377,0.994311,0.0
2,-0.955755,-0.640269,-1.076678,0.978837,0.541126,1.586141,2.662478,2.379204,0.993017,-1.366224,...,-0.414731,0.505255,-1.105734,-0.891332,0.667471,0.047161,-1.489766,1.431225,0.953298,0.0
3,0.151534,0.322207,-1.154273,-0.352811,0.178547,0.760418,1.133446,-0.24393,0.610778,-0.542478,...,-0.376377,0.916929,-1.099116,-0.982928,0.318741,0.430117,-1.1909,1.347082,1.114862,0.0
4,-0.934557,0.206794,-1.458136,-1.488192,0.629494,1.955989,2.185993,0.941259,1.750161,-1.073258,...,-0.303363,0.933871,-1.38036,0.081186,0.506747,0.21265,-1.234667,1.046096,1.404367,0.0


In [38]:
# Persist the test embeddings with the 'prices' column; index=False omits the row index.
test_data.to_csv('test_clip_embed.csv', index=False)