In [None]:
!pip install transformers timm fairscale datasets
!git clone https://github.com/salesforce/BLIP

In [None]:
# Mount Google Drive (Colab-only API) at /content/gdrive.
# NOTE(review): nothing else visible in this notebook reads from the mount —
# presumably it is used to copy source images into the dataset folder by hand.
from google.colab import drive

gdrive_mount_point = '/content/gdrive'
drive.mount(gdrive_mount_point)

In [None]:
# Name of the generated dataset: used as the save_to_disk directory and as the Hub repo name.
name_for_dataset = "historic_images"
# Folder scanned for images. Relative path, so it resolves against the current working
# directory (the BLIP repo root once the `%cd /content/BLIP` cell has run).
local_dataset_folder = "./historic_img/"

In [None]:
%cd /content/BLIP
!mkdir $local_dataset_folder

In [None]:
import datasets
from datasets import Dataset
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict

import sys
import os
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from models.blip import blip_decoder

# Module-level state: setup() fills in torch_device, transform and model;
# process_image_folder() fills in data. All start out as None.
torch_device = transform = model = data = None


In [None]:
def setup(image_size,
          model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth',
          vit='base'):
    """Initialise the torch device, the image transform and the BLIP captioning model.

    Sets the module-level globals ``torch_device``, ``transform`` and ``model`` that
    ``load_image_for_blip_inference`` and ``blip_caption_for_image`` read.

    The model is built with ``blip_decoder`` from the cloned BLIP repo (imported in the
    imports cell). Its ``generate`` accepts the ``sample``/``top_p``/``max_length``/
    ``min_length`` keywords that ``blip_caption_for_image`` passes.

    Args:
        image_size: side length in pixels that images are resized to
            (``Resize((image_size, image_size))``); also passed to ``blip_decoder``.
        model_url: checkpoint URL or local path, passed as ``pretrained`` to
            ``blip_decoder``. The default is the BLIP base (CapFilt-L) checkpoint
            used in the BLIP repo's demo — NOTE(review): verify it is still reachable.
        vit: vision-backbone size passed to ``blip_decoder``. ``'base'`` matches the
            default checkpoint; presumably a ``'large'`` backbone needs a large checkpoint.
    """
    global torch_device
    global transform
    global model

    torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # NOTE(review): per-channel mean/std presumably matching the preprocessing the
        # BLIP checkpoint was trained with — keep in sync if the checkpoint changes.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

    model = blip_decoder(pretrained=model_url, image_size=image_size, vit=vit)
    model.eval()
    # Move to the detected device instead of hard-coding "cuda", so CPU-only runtimes work.
    model = model.to(torch_device)

def load_image_for_blip_inference(img_path):
    """Load an image file as a transformed, batched tensor ready for the BLIP model.

    Requires ``setup()`` to have run so that the globals ``transform`` and
    ``torch_device`` are set.

    Args:
        img_path: path to the image file (``str`` or ``Path``).

    Returns:
        ``transform(rgb_image)`` with a leading batch dimension of size 1
        (``unsqueeze(0)``), moved to ``torch_device``. With the transform built by
        ``setup()`` this is presumably shaped (1, 3, image_size, image_size).

    Raises:
        RuntimeError: if ``setup()`` has not been called yet.
    """
    if transform is None:
        raise RuntimeError("setup() must be called before loading images for inference")
    # Context manager releases the file handle; convert() returns an independent,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(img_path) as raw_image:
        rgb_image = raw_image.convert('RGB')
    return transform(rgb_image).unsqueeze(0).to(torch_device)

def blip_caption_for_image(img_path, sample=True, top_p=0.9, max_length=64, min_length=5):
    """Generate one caption for the image at ``img_path`` using the global BLIP ``model``.

    Requires ``setup()`` to have been called first.

    Args:
        img_path: path to the image file (``str`` or ``Path``).
        sample: forwarded to ``model.generate``. With ``True`` the caption is
            presumably sampled, so seed torch for reproducible output.
        top_p: forwarded to ``model.generate``.
        max_length: forwarded to ``model.generate``.
        min_length: forwarded to ``model.generate``.

    Returns:
        The first element of ``model.generate``'s result (the caption text).

    Raises:
        RuntimeError: if ``setup()`` has not been called yet.
    """
    if model is None:
        raise RuntimeError("setup() must be called before captioning images")

    image = load_image_for_blip_inference(img_path)
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        captions = model.generate(
            image, sample=sample, top_p=top_p, max_length=max_length, min_length=min_length
        )
    return captions[0]


In [None]:
def _find_image_files(folder_name, extensions):
    """Return files under ``folder_name`` (searched recursively) whose suffix is in ``extensions``.

    Suffix matching is case-insensitive (``.JPG`` counts as ``.jpg``). The result is
    sorted, so dataset row order does not depend on filesystem listing order.
    """
    wanted = {ext.lower() for ext in extensions}
    return sorted(
        path for path in Path(folder_name).rglob("*")
        if path.is_file() and path.suffix.lower() in wanted
    )


def process_image_folder(folder_name, dataset_name, extensions=(".jpg",)):
    """Caption every image under ``folder_name`` with BLIP and save an image/text dataset.

    Requires ``setup()`` to have run first (``blip_caption_for_image`` uses the global model).

    Args:
        folder_name: directory searched recursively for images.
        dataset_name: directory passed to ``Dataset.save_to_disk``.
        extensions: file suffixes to include, compared case-insensitively.

    Returns:
        The generated ``datasets.Dataset``. It has an ``image`` column
        (``datasets.Image`` feature) and a ``text`` column holding the BLIP captions.

    Raises:
        ValueError: if no matching image files are found under ``folder_name``.

    Side effects:
        Sets the module-level global ``data`` to the column dict used to build the
        dataset, then writes the dataset to ``dataset_name``.
    """
    global data

    files = _find_image_files(folder_name, extensions)
    if not files:
        raise ValueError(
            f"No image files with suffixes {tuple(extensions)} found under {folder_name!r}"
        )

    data = {"image": [], "text": []}
    for file in tqdm(files):
        # Record the path instead of an open PIL image per file. Thousands of lazily
        # loaded Image.open() objects would keep their file handles open until the end.
        data["image"].append(str(file))
        data["text"].append(blip_caption_for_image(file))

    # Casting the path column to the Image feature lets `datasets` read the files itself.
    # `datasets.Image` is spelled out because `Image` in this notebook is PIL's module.
    dataset = Dataset.from_dict(data).cast_column("image", datasets.Image())

    print("Dataset was generated.")
    print(dataset)
    print(dataset[0])
    dataset.save_to_disk(dataset_name)
    print("Dataset saved to disk.")
    return dataset

In [None]:
def push_to_hub(dataset_name=None, hf_user=None, hf_token=None, private=False):
    """Upload a dataset saved by ``process_image_folder`` to the Hugging Face Hub.

    Credentials are not hard-coded. They come from the arguments or from the
    ``HF_USER`` / ``HF_TOKEN`` environment variables.

    Args:
        dataset_name: local ``save_to_disk`` directory, also used as the Hub repo name.
            Defaults to the notebook-level ``name_for_dataset``.
        hf_user: Hub user or organisation namespace. Defaults to ``$HF_USER``.
            When neither is set, the bare ``dataset_name`` is used as the repo id,
            which ``datasets`` resolves to the logged-in user's namespace.
        hf_token: Hub access token. Defaults to ``$HF_TOKEN``. When neither is set,
            ``None`` is passed and ``datasets`` falls back to the locally saved login
            token (e.g. from ``huggingface-cli login``).
        private: whether to create the Hub repository as private.
    """
    if dataset_name is None:
        dataset_name = name_for_dataset
    hf_user = hf_user or os.environ.get("HF_USER", "")
    # Normalise "" to None so an empty env var does not shadow a cached login.
    hf_token = hf_token or os.environ.get("HF_TOKEN") or None

    dataset = Dataset.load_from_disk(dataset_name)
    # Avoid the invalid id "/<name>" when no user/namespace is configured.
    remote_hub_repo = f"{hf_user}/{dataset_name}" if hf_user else dataset_name
    dataset.push_to_hub(remote_hub_repo, token=hf_token, private=private)

In [None]:
  # End-to-end run: load BLIP, caption every image in the folder, save locally, upload.
  image_size_for_blip_inference = 384  # side length images are resized to in setup()
  setup(image_size_for_blip_inference)
  # blip_caption_for_image samples captions (sample=True by default). Seed torch after
  # the model is built so re-runs give reproducible captions (GPU kernels may still add
  # some nondeterminism).
  caption_seed = 0
  torch.manual_seed(caption_seed)
  process_image_folder(local_dataset_folder, name_for_dataset)
  push_to_hub()