In [None]:
import torch
import torch.nn as nn
import timm
from torchvision import transforms
import os
from os import listdir
from os.path import splitext
import numpy as np
from PIL import Image
import pickle as pk
from tqdm import tqdm

# Directory that holds the input images (relative to the working directory).
IMAGES_PATH = "./../images"


def read_img_file(f):
    """Open the image at ``f`` and return it as an RGB PIL image.

    An image whose mode is already ``'RGB'`` is returned exactly as opened;
    an image in any other mode is converted to RGB first.
    """
    opened = Image.open(f)
    # Forced-greyscale variant kept for reference:
    # opened = opened.convert('L').convert('RGB')
    return opened if opened.mode == 'RGB' else opened.convert('RGB')

# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained BEiT backbone loaded through timm.
model = timm.create_model('beit_base_patch16_224_in22k', pretrained=True)

# Replace the classification head with a pass-through layer so the model
# output is whatever would otherwise have been fed into the head.
model.head = nn.Identity()

# Alternatives kept from earlier experiments, noted as being for ResNet
# models (the exact model was not recorded):
#   model.global_pool = timm.models.layers.SelectAdaptivePool2d(pool_type="max", flatten=True)
#   model.fc = nn.Identity()
# or:
#   model.head.global_pool = timm.models.layers.SelectAdaptivePool2d(pool_type="max", flatten=True)
#   model.head.fc = nn.Identity()
#   model.head.flatten = nn.Identity()
# print(model)  # handy for inspecting layer names

# Switch to evaluation mode and move the weights to the selected device;
# both calls return the module itself, so the rebinding keeps the same object.
model = model.eval().to(device)

# Per-channel mean/std pair; not referenced by the pipeline built below.
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]

# Per-channel mean/std pair that the pipeline below actually applies.
IMAGENET_DEFAULT_MEAN_21k = [0.5, 0.5, 0.5]
IMAGENET_DEFAULT_STD_21k = [0.5, 0.5, 0.5]

# Preprocessing pipeline: resize straight to 224x224, convert to a tensor,
# then normalize each channel with the 21k statistics.
_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    # transforms.CenterCrop(224),  # disabled: the resize above already yields 224x224
    transforms.ToTensor(),
    transforms.Normalize(
        mean=IMAGENET_DEFAULT_MEAN_21k,
        std=IMAGENET_DEFAULT_STD_21k,
    ),
])


def transform(im):
    """Run ``im`` through the module-level ``_transform`` pipeline and return the result."""
    preprocessed = _transform(im)
    return preprocessed
# def transform(im):
#   desired_size = 224
#   old_size = im.size  # old_size[0] is in (width, height) format
#   ratio = float(desired_size)/max(old_size)
#   new_size = tuple([int(x*ratio) for x in old_size])
#   im = im.resize(new_size, Image.ANTIALIAS)
#   new_im = Image.new("RGB", (desired_size, desired_size))
#   new_im.paste(im, ((desired_size-new_size[0])//2, (desired_size-new_size[1])//2))
#   return _transform(new_im)

# True until the first feature vector has been computed; used to print the
# feature shape exactly once.
first_img = True


def get_features(f):
    """Return the model output for the image file ``f`` as a NumPy array.

    The image is read with ``read_img_file``, preprocessed with ``transform``,
    given a leading batch dimension of size one, and passed through ``model``
    with gradient tracking disabled. The single batch entry is returned.

    On the first call only, the shape of the returned array is printed.
    """
    global first_img
    pil_image = read_img_file(f)
    batch = transform(pil_image).unsqueeze(0)
    batch = batch.to(device)
    with torch.no_grad():
        output = model(batch)
        image_features = output.cpu().numpy()[0]
    if not first_img:
        return image_features
    # First call: report the feature shape once, then stay silent.
    first_img = False
    print(image_features.shape)
    return image_features


def generate_resnet_features():
    """Extract features for every file in ``IMAGES_PATH`` and pickle them in batches.

    Despite the historical name, the features come from whatever module-level
    ``model`` is used by ``get_features``.

    The directory listing is sorted and split into batches of 100000 files.
    Each batch is written to ``./beit_fips_224x224_color_{N*100}k.pkl`` (N is
    the 1-based batch number) as a pickled list of
    ``{'image_id': <file name without extension>, 'features': <array>}`` dicts.
    A batch whose output file already exists is skipped, so an interrupted
    run can be resumed.

    Files that fail to load or process are reported and left out of the
    output; they do not abort the run.
    """
    batch_size = 100000
    image_filenames = sorted(listdir(IMAGES_PATH))
    batch_starts = range(0, len(image_filenames), batch_size)
    for batch_num, start in enumerate(batch_starts, start=1):
        save_path = f"./beit_fips_224x224_color_{batch_num*100}k.pkl"
        if os.path.exists(save_path):
            print(f"{save_path} exists, skipping")
            continue
        batch = image_filenames[start:start + batch_size]
        all_image_features = []
        for image_filename in tqdm(batch):
            image_id = splitext(image_filename)[0]
            try:
                image_features = get_features(os.path.join(IMAGES_PATH, image_filename))
            except Exception as exc:
                # Best effort: skip unreadable/broken files but say why.
                # ``Exception`` (not a bare ``except``) keeps Ctrl-C and
                # SystemExit able to stop the run.
                print("something is wrong!!!", image_filename, repr(exc))
                continue
            all_image_features.append({'image_id': image_id, 'features': image_features})
        # Write to a temporary file and move it into place afterwards, so a
        # crash mid-write never leaves a truncated pickle that the
        # "exists, skipping" check above would treat as a finished batch.
        tmp_path = save_path + ".tmp"
        with open(tmp_path, "wb") as out_file:
            pk.dump(all_image_features, out_file)
        os.replace(tmp_path, save_path)

# Run the batched feature extraction as soon as this cell/script is executed.
generate_resnet_features()