# Antarctic Prompt


## Setup Gdrive

In [None]:
# Mount Google Drive at /content/drive; the model files and the request
# directory used later in this notebook both live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

## Antarctic Prompt Setup

In [None]:

# Number of captions reported per image (Colab form field). It is used as
# `display` in the settings below and as the slice length in the request loop.
antarctic_number_of_captions = 1#@param {type: 'number'}

# Clone the captioning code and step into it so that the later
# `import model` / `import utils` resolve (presumably against that checkout —
# verify); CLIP is cloned inside it and imported below as the `CLIP` package.
!git clone https://github.com/dzryk/antarctic-captions.git
%cd antarctic-captions/
!git clone https://github.com/openai/CLIP
# Python dependencies installed into the Colab runtime.
!pip3 install gdown
!pip3 install ftfy
!pip3 install transformers
#!pip3 install git+https://github.com/PyTorchLightning/pytorch-lightning
!pip install torch pytorch-lightning
# Download models and cache: mirror the public model directory into Google
# Drive. wget flags: -m mirrors recursively, -np never ascends to the parent
# directory, -c resumes partial downloads, -w 2 waits two seconds between
# requests, -P sets the target directory and -R skips the index pages.

!wget -m -np -c -U "eye02" -w 2 -P "/content/drive/MyDrive/AI/models/antarctic-captions/" -R "index.html*" "https://the-eye.eu/public/AI/models/antarctic-captions/"
import argparse
import io
import numpy as np
import torch
import torch.nn as nn
import requests
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TVTFF

from CLIP import clip
from PIL import Image
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision.utils import make_grid

import model
import utils
# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Using device:', device)
# Helper functions
def fetch(url_or_path):
    """Open an image source as a binary file-like object.

    ``http://`` and ``https://`` locations are downloaded in full and
    returned as an in-memory buffer positioned at the start; an HTTP error
    status raises through ``raise_for_status()``. Anything else is treated
    as a local path and opened in binary read mode. The caller owns the
    returned handle.
    """
    location = str(url_or_path)
    if location.startswith(('http://', 'https://')):
        response = requests.get(url_or_path)
        response.raise_for_status()
        # A BytesIO built from the payload already starts at offset 0.
        return io.BytesIO(response.content)
    return open(url_or_path, 'rb')

def load_image(img, preprocess):
    """Load an image and build its model input.

    ``img`` is a URL or local path accepted by ``fetch``. Returns a pair of
    the opened PIL image and ``preprocess(image)`` with a leading axis added
    by ``unsqueeze(0)``, moved to the module-level ``device``.
    """
    handle = fetch(img)
    picture = Image.open(handle)
    batch = preprocess(picture).unsqueeze(0)
    return picture, batch.to(device)

def show(imgs):
    """Draw one image tensor, or a list of them, in a single matplotlib row.

    Each tensor is detached, converted to a PIL image and shown on its own
    axis with ticks and tick labels removed.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array 2-D even for a single image.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for axis, tensor in zip(axes[0], images):
        pil_image = TVTFF.to_pil_image(tensor.detach())
        axis.imshow(np.asarray(pil_image))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

def display_grid(imgs):
    """Resize each PIL image to 256x256 and show them all as one grid."""
    tiles = []
    for picture in imgs:
        thumbnail = picture.resize((256, 256))
        tiles.append(TVTFF.to_tensor(thumbnail))
    grid = make_grid(tiles)
    show(grid)
    
def clip_rescoring(args, net, candidates, x):
    """Rank candidate captions by similarity to the image embedding ``x``.

    The candidates are tokenized with CLIP, encoded by
    ``net.perceiver.encode_text`` and L2-normalised along the last axis. The
    scaled dot products with ``x`` go through a softmax, and the top
    ``args.num_return_sequences`` candidates are returned best-first.

    NOTE(review): the double ``[0]`` indexing assumes ``x`` has an extra
    leading singleton axis, as produced by the caller's ``x[None, :]`` —
    confirm before calling with other shapes.
    """
    tokens = clip.tokenize(candidates).to(args.device)
    text_features = net.perceiver.encode_text(tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
    scores = (100.0 * x @ text_features.T).softmax(dim=-1)
    ranked = scores[0].topk(args.num_return_sequences).indices
    return [candidates[position] for position in ranked[0]]

def loader(args):
    """Build the captioning network and the CLIP preprocessing transform.

    Reads the cached text lines from ``args.textfile`` (whitespace-stripped)
    and their embeddings from ``args.embfile``, loads the checkpoint through
    ``utils.load_ckpt(args)`` and attaches both caches to the network, the
    embeddings as a tensor on ``args.device``. Returns ``(net, preprocess)``,
    where ``preprocess`` is the second item returned by ``clip.load``.
    """
    with open(args.textfile) as handle:
        cached_text = [row.strip() for row in handle]
    embeddings = np.load(args.embfile)

    net = utils.load_ckpt(args)
    net.cache = cached_text
    net.cache_emb = torch.tensor(embeddings).to(args.device)

    # Only the preprocessing transform of the CLIP pair is needed here.
    _, preprocess = clip.load(args.clip_model, jit=False)
    return net, preprocess
    
def caption_image(path, args, net, preprocess):
    """Generate captions for the image at ``path`` and return them ranked.

    The image is loaded and preprocessed, and ``utils.build_table`` produces
    a text table plus the value ``x`` that is later passed to the rescoring
    step. The first table entry is tokenized and fed to
    ``net.model.generate`` with the sampling settings from ``args``. The
    decoded sequences are re-ranked by ``clip_rescoring``, the first
    ``args.display`` captions are printed, and the full ranked list is
    returned.
    """
    print('in caption_image')
    _image, mat = load_image(path, preprocess)
    table, x = utils.build_table(
        mat.to(device),
        perceiver=net.perceiver,
        cache=net.cache,
        cache_emb=net.cache_emb,
        topk=args.topk,
        return_images=True,
    )
    prompt_ids = net.tokenizer.encode(table[0], return_tensors='pt').to(device)

    sampling = dict(
        do_sample=args.do_sample,
        num_beams=args.num_beams,
        temperature=args.temperature,
        top_p=args.top_p,
        num_return_sequences=args.num_return_sequences,
    )
    generated = net.model.generate(prompt_ids, **sampling)
    candidates = [
        net.tokenizer.decode(seq, skip_special_tokens=True) for seq in generated
    ]

    # x gets a leading axis because clip_rescoring indexes its result twice.
    captions = clip_rescoring(args, net, candidates, x[None, :])
    for caption in captions[:args.display]:
        print(caption)
    return captions
# Settings
# Directory holding the checkpoint and cache files; the nested
# the-eye.eu/public/... layout matches the wget mirror made during setup.
antarctic_filedir='/content/drive/MyDrive/AI/models/antarctic-captions/the-eye.eu/public/AI/models/antarctic-captions'
# Settings bundle passed to loader(), caption_image() and clip_rescoring().
antarctic_args = argparse.Namespace(
    # Checkpoint path; presumably read by utils.load_ckpt(args) — verify.
    ckpt=f'{antarctic_filedir}/-epoch=05-vloss=2.163.ckpt',
    # Cached text lines and their embeddings, read by loader().
    textfile=f'{antarctic_filedir}/postcache.txt',
    embfile=f'{antarctic_filedir}/postcache.npy',
    # Model name handed to clip.load().
    clip_model='ViT-B/16',
    # Forwarded to utils.build_table() as its `topk` argument.
    topk=10,
    # Number of sequences requested from generate(), and the k used when
    # ranking candidates in clip_rescoring().
    num_return_sequences=1000,
    num_beams=1,
    temperature=1.0,
    top_p=1.0,
    # How many of the ranked captions caption_image() prints.
    display=antarctic_number_of_captions,
    do_sample=True,
    device=device
)
# Load checkpoint and preprocessor
antarctic_net, antarctic_preprocess = loader(antarctic_args)
# Step back out of the antarctic-captions/ checkout entered during setup.
%cd ..

# Directory polled by the request loop for incoming images; the resulting
# prompts file is written into the same directory.
antarctic_prompt_request_directory = '/content/drive/MyDrive/AI/antarctic_prompt/request/'

# Antarctic request loop

In [None]:
import os
import time
import sys 
timeout_minutes=60 #@param {type:"integer"}
# Serve captioning requests until the loop has been idle for
# `timeout_minutes`.
#
# A request is any .png file placed in antarctic_prompt_request_directory.
# Each one is captioned, the captions are written to antarctic_prompts.txt in
# the same directory (joined with ';', matching the printed form), and the
# image is then removed. If several images are pending, each overwrites the
# reply file in turn, so the last one processed wins.
#
# The idle clock restarts after every handled request.
start_time = time.time()
while True:
    # Only PNG files count as pending work. The reply file lives in this same
    # directory, so testing for "directory is empty" would never be true again
    # after the first reply: the loop would spin without sleeping and the
    # timeout could never fire.
    pending_images = sorted(
        name for name in os.listdir(antarctic_prompt_request_directory)
        if name.endswith('.png')
    )
    if not pending_images:
        time.sleep(1)
        if (time.time() - start_time) > timeout_minutes*60:
            print('timeout')
            sys.exit()
        continue

    for file in pending_images:
        # Caption and delete the file that was actually found, instead of
        # assuming it is always named init_image.png.
        init_image = os.path.join(antarctic_prompt_request_directory, file)
        new_antarctic_prompts = caption_image(init_image, antarctic_args, antarctic_net, antarctic_preprocess)[:antarctic_number_of_captions]
        joined_prompts = ';'.join(new_antarctic_prompts)
        print(joined_prompts)
        # Write the captions with the same ';' separator that is printed, so
        # several captions do not run together in the reply file.
        with open(os.path.join(antarctic_prompt_request_directory, 'antarctic_prompts.txt'), 'w') as f:
            f.write(joined_prompts)
        os.remove(init_image)
        start_time = time.time()