In [6]:
import sys, os, importlib, requests, utils, pandas as pd, numpy as np, torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel
from tqdm import tqdm

importlib.reload(utils);

In [3]:
# feature extraction model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = utils.model_config()['feature-extraction']['image']['model-name']
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTModel.from_pretrained(model_name).to(device)

In [4]:
df = pd.read_csv("Y:/Internship/SoMin/data/df_cpm_scaled_valid_img2.csv")

In [5]:
df.shape

(30432, 49)

In [5]:
df.iloc[0]

Unnamed: 0                                                           0
search_term                                                    ad-tech
country                                                             US
page_id                                                100470349235347
page_name                                                   Save Texas
ad_id                                                  762749274711708
ad_creation_date                                            2022-05-03
ad_creation_month                                                    5
delivery_start                                              2022-05-03
delivery_stop                                               2022-05-06
delivery_period                                                      3
ad_url               https://www.facebook.com/ads/archive/render_ad...
ad_body              In a technical report updated for the first ti...
uses_multi_body                                                      0
link_c

In [7]:
filepath = "Y:/Internship/SoMin/Image/762749274711708.png"
image = Image.open(filepath)

In [8]:
inputs = feature_extractor(images=image, return_tensors="pt").to(device)
outputs = model(**inputs)
emb_array = outputs.last_hidden_state.cpu()[0][0].detach().numpy()

In [10]:
emb_array.shape

(768,)

In [11]:
ad_ids, embeddings = [], []
IMG_DIR = "Y:/Internship/SoMin/Image"
for _, row in tqdm(df.iterrows()):
    
    ad_id = row['ad_id']
    try:
        # image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
        image = Image.open(f"{IMG_DIR}/{ad_id}.png")
        inputs = feature_extractor(images=image, return_tensors="pt").to(device)
        outputs = model(**inputs) # https://github.com/huggingface/transformers/issues/2704
        emb_array = outputs.last_hidden_state.cpu()[0][0].detach().numpy()#.reshape(1, -1)
        ad_ids.append(ad_id)
        embeddings.append(emb_array.tolist())
        
    except:
        pass
    # embeddings.append(','.join(list(map(str, emb_array))))

30432it [4:50:36,  1.75it/s]


In [12]:
image_features = pd.concat([pd.DataFrame(ad_ids, columns=['ad_id']), 
                            pd.DataFrame(embeddings, columns=[f'img{i}' for i in range(len(embeddings[0]))])], 
                           axis=1)

In [13]:
# clean_dir = os.path.abspath('..').replace('\\', '/') + f'/dataset/clean'
image_features.to_csv('Y:/Internship/SoMin/clean/image_features_trial1_0206.csv', index=False)
image_features.to_pickle('Y:/Internship/SoMin/clean/image_features_trial1_0206.pkl')

In [15]:
image_features

Unnamed: 0,ad_id,img0,img1,img2,img3,img4,img5,img6,img7,img8,...,img758,img759,img760,img761,img762,img763,img764,img765,img766,img767
0,762749274711708,0.164756,-0.252698,-0.074716,0.068187,-0.131181,0.088666,-0.025290,-0.155478,-0.235019,...,0.465630,0.002484,-0.085118,0.052690,-0.236282,-0.064323,0.325325,-0.202712,-0.088441,-0.287930
1,563054608471314,0.164756,-0.252698,-0.074716,0.068187,-0.131181,0.088666,-0.025290,-0.155478,-0.235019,...,0.465630,0.002484,-0.085118,0.052690,-0.236282,-0.064323,0.325325,-0.202712,-0.088441,-0.287930
2,930566127614072,0.234983,0.027089,0.022409,0.134131,0.120969,-0.017743,0.069042,-0.146920,0.047015,...,-0.179992,0.032911,0.030026,-0.294595,-0.231033,-0.003755,-0.199238,-0.027423,-0.078111,-0.131632
3,738864343939155,0.296482,-0.013693,-0.392394,-0.016913,-0.093951,0.216531,-0.167886,-0.022603,-0.021857,...,-0.146833,-0.075014,-0.054048,-0.222878,0.120308,-0.341486,-0.292918,-0.210843,0.045422,-0.129597
4,512240050437223,0.197882,-0.007214,-0.138495,-0.121283,-0.311274,0.164506,-0.133318,-0.233321,-0.124335,...,-0.021403,0.025870,-0.075021,-0.094134,0.051662,-0.187134,-0.271364,-0.129058,-0.115162,0.132305
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29756,381289572628044,0.124914,-0.258586,-0.085897,-0.063307,0.093259,0.113988,0.029446,-0.234514,0.251592,...,0.112608,-0.019068,-0.082013,-0.202433,-0.156729,0.015124,-0.232641,-0.035432,-0.121599,-0.171165
29757,517753845398018,-0.203471,0.036113,-0.009781,-0.145538,-0.194277,-0.338150,0.007197,0.028206,-0.002539,...,-0.043676,-0.007365,-0.020420,-0.408333,-0.079077,0.191462,-0.339416,-0.337108,-0.168544,-0.264438
29758,1251766551633027,0.010495,-0.067719,-0.139276,-0.080184,0.122595,-0.045484,-0.114048,-0.163266,0.065660,...,-0.036252,-0.033005,-0.099844,-0.245278,-0.182185,-0.055517,-0.100192,-0.015562,-0.046616,0.002111
29759,205891737021220,0.126666,-0.238716,0.088393,-0.064244,0.000469,-0.170381,0.068976,-0.125390,0.138971,...,0.123096,-0.014946,-0.132246,-0.349340,-0.189837,0.163440,-0.058997,-0.095573,-0.131043,0.032312
