In [1]:
from os.path import join, isfile, isdir
from os import listdir
from multiprocessing import Pool
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
from shutil import copy
import torch.nn as nn
import requests
import matplotlib.pyplot as plt
import urllib.request
from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection, AutoTokenizer
import torch
# import ipywidgets as widgets
from IPython.display import display, clear_output
import io

from torch.utils.data.dataset import Dataset

import os
import torch
import fnmatch
import numpy as np
import pandas as pd
import pdb
import random
from loguru import logger

%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def download_with_progress_bar(url, output_path, block_size=1 << 20):
    """Stream `url` to `output_path`, showing a tqdm progress bar.

    The response is written to `output_path + ".part"` and renamed into place only after
    it has been read completely. An interrupted download therefore never leaves a
    truncated file at `output_path`. Missing parent directories are created.

    Args:
        url: URL to fetch (anything `urllib.request.urlopen` accepts).
        output_path: destination file path.
        block_size: bytes read per chunk (default 1 MiB; the original used 1 KiB).

    Returns:
        `output_path`, so callers can bind the result (previously returned None).
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    tmp_path = output_path + ".part"

    with urllib.request.urlopen(url) as response:
        # Content-Length may be absent; None gives tqdm an open-ended bar instead of total=0.
        total_size = int(response.headers.get('Content-Length', 0)) or None
        # Context managers close the bar and the file even if reading raises.
        with tqdm(total=total_size, unit='iB', unit_scale=True, unit_divisor=1024) as progress_bar:
            with open(tmp_path, 'wb') as file:
                while True:
                    data = response.read(block_size)
                    if not data:
                        break
                    file.write(data)
                    progress_bar.update(len(data))

    # The temp file lives next to the target, so this rename is atomic on POSIX.
    os.replace(tmp_path, output_path)
    print(f"File downloaded and saved to {output_path}")
    return output_path

In [3]:
class GRAFT(nn.Module):
    """CLIP vision backbone plus a per-token projector, as used by the GRAFT checkpoints.

    The attribute names `satellite_image_backbone`, `projector` and (when `temp=True`)
    `logit_scale` determine the state-dict keys. Keep them unchanged so checkpoints
    still load with `load_state_dict`.

    Args:
        CLIP_version: Hugging Face model id passed to
            `CLIPVisionModelWithProjection.from_pretrained`.
        temp: if True, register a `logit_scale` buffer initialised to 1/0.07.
        bias_projector: whether the projector's Linear layer has a bias term.
    """

    def __init__(self, CLIP_version="openai/clip-vit-base-patch16", temp=False, bias_projector=True):
        super().__init__()
        # satellite image backbone
        self.satellite_image_backbone = CLIPVisionModelWithProjection.from_pretrained(CLIP_version)
        config = self.satellite_image_backbone.config
        self.patch_size = config.patch_size

        # Per-token head: LayerNorm + Linear mapping hidden_size -> projection_dim.
        self.projector = nn.Sequential(
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
            nn.Linear(config.hidden_size, config.projection_dim, bias=bias_projector),
        )
        # Dimension along which every returned feature tensor is L2-normalised.
        self.norm_dim = -1

        self.temp = temp
        if temp:
            # A buffer (not a Parameter): saved/loaded with the state dict, not trained.
            self.register_buffer("logit_scale", torch.ones([]) * (1 / 0.07))

    def forward(self, image_tensor):
        """Return L2-normalised per-token features.

        Runs the backbone, applies `self.projector` to every token of
        `last_hidden_state`, and normalises along `self.norm_dim`.
        """
        # B x 197 x 768 for VIT-B/16
        hidden_state = self.satellite_image_backbone(image_tensor).last_hidden_state
        # B x 197 x 512
        return F.normalize(self.projector(hidden_state), dim=self.norm_dim)

    def forward_features(self, image_tensor):
        """Return the backbone's L2-normalised `image_embeds` (one vector per image).

        Unlike `forward`, this path does not use `self.projector`.
        """
        # B x 512 for VIT-B/16
        embed = self.satellite_image_backbone(image_tensor).image_embeds
        # Same normalisation axis as `forward`; for (B, D) embeddings this equals dim=1.
        return F.normalize(embed, dim=self.norm_dim)

In [4]:
# Base URL of the GRAFT checkpoints. No trailing slash: file names are appended as
# `model_dir + '/<file>'`, which would otherwise put "//" into the URL.
model_dir = 'https://graft.cs.cornell.edu/static/models/<PLACEHOLDER>'

In [5]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GRAFT(temp=True, bias_projector=False).to(device)

# Strip any trailing "/" from model_dir so the joined URL never contains "//".
ckpt_url = model_dir.rstrip('/') + '/graft_sentinel.ckpt'
ckpt_path = "checkpoints/graft/sentinel_model.ckpt"
output_path = ckpt_path
# Skip the ~1 GB download when a previous run already saved the checkpoint.
# If an earlier download was interrupted, delete the file and re-run this cell.
if not os.path.isfile(ckpt_path):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    download_with_progress_bar(ckpt_url, ckpt_path)
ckpt_file = ckpt_path

# ToTensor is intentionally absent: ClassificationDataset already applies it,
# so this pipeline receives float tensors shaped (..., 3, H, W).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])

# map_location lets a checkpoint saved on GPU load onto the selected device.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
sd = torch.load(ckpt_path, map_location=device)
# strict=False tolerates key mismatches; report them instead of ignoring them silently.
incompatible = model.load_state_dict(sd['state_dict'], strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"load_state_dict(strict=False): {len(incompatible.missing_keys)} missing keys "
          f"(e.g. {incompatible.missing_keys[:3]}), {len(incompatible.unexpected_keys)} unexpected keys "
          f"(e.g. {incompatible.unexpected_keys[:3]})")
# Inference only, matching the text model below.
model.eval()
textmodel = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch16").eval().to(device)
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

100%|██████████| 1.03G/1.03G [00:09<00:00, 104MiB/s] 


File downloaded and saved to checkpoints/graft/sentinel_model.ckpt


In [6]:
class ClassificationDataset(Dataset):
    """Image-classification dataset backed by a CSV index and a class-name metadata file.

    Reads two files from `data_dir`:
      * `<dataset_name>_data.csv`, which must provide the columns `fp` (image path),
        `label` (class index), `class_name` and `split`;
      * `<dataset_name>_metadata.npy`, a pickled object whose `"classes"` entry maps a
        label index to its class-name string.

    Each item is `(image, label)`. `image` is the output of `transforms.ToTensor()`
    applied to the RGB-converted file.

    Args:
        dataset_name: prefix of the CSV / metadata file names.
        split: value of the `split` column to keep (e.g. "test").
        data_dir: directory containing the files (default "../data").
    """

    def __init__(self, dataset_name: str = "eurosat", split='test', data_dir: str = "../data"):
        self.split = split

        # Filter on the requested split (previously hardcoded to "test").
        df = self._read_split(data_dir, dataset_name, split)

        # NOTE(review): allow_pickle=True can execute code on load; only use trusted files.
        with open(os.path.join(data_dir, f"{dataset_name}_metadata.npy"), "rb") as f:
            self.label_to_class = np.load(f, allow_pickle=True)[()]["classes"]  # maps from idx (0-9) to actual string
        self.fps = df["fp"].values
        self.labels = df["label"].values
        print(f"label and class mapping: \n{df[['class_name','label']].value_counts().sort_index()}")
        self.num_outputs = len(self.label_to_class)

        self.data_len = len(self.fps)

        self.transforms_list = [transforms.ToTensor()]

    @staticmethod
    def _read_split(data_dir, dataset_name, split):
        """Load `<data_dir>/<dataset_name>_data.csv` and keep rows whose `split` equals `split`."""
        df = pd.read_csv(os.path.join(data_dir, f"{dataset_name}_data.csv"))
        return df[df["split"] == split]

    def __getitem__(self, index):
        fp = self.fps[index]
        label = self.labels[index]
        # Load eagerly so the file handle is closed, and force 3 channels for ToTensor.
        with Image.open(fp) as img:
            image = img.convert("RGB")

        data_transforms = transforms.Compose(self.transforms_list)
        image = data_transforms(image)  # output of transforms: (3, H, W)

        return image, label

    def __len__(self):
        return self.data_len



In [7]:

# Evaluation data: the EuroSAT "test" split, read in file order.
test_dataset1 = ClassificationDataset(dataset_name="eurosat", split="test")
print("transform from dataset:", test_dataset1.transforms_list)

loader_kwargs = dict(
    batch_size=16,
    shuffle=False,      # deterministic batch order across runs
    num_workers=1,
    pin_memory=True,
    sampler=None,
    drop_last=False,    # keep the final partial batch so every image is scored
)
test_loader = torch.utils.data.DataLoader(test_dataset1, **loader_kwargs)
# NOTE: one shared iterator -- each next() consumes a batch, so draining it
# requires re-running this cell before another pass.
test_dataset = iter(test_loader)

label_to_class = test_dataset1.label_to_class
print(label_to_class)

label and class mapping: 
class_name            label
AnnualCrop            0        566
Forest                1        602
HerbaceousVegetation  2        593
Highway               3        515
Industrial            4        485
Pasture               5        390
PermanentCrop         6        486
Residential           7        628
River                 8        504
SeaLake               9        596
Name: count, dtype: int64
transform from dataset: [ToTensor()]
['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']


In [8]:
# Map the dataset's CamelCase class names to natural-language phrases for the prompts.
map_clean_text = {
    'AnnualCrop': "annual crop",
    'Forest': "forest",
    'HerbaceousVegetation': "herbaceous vegetation",
    'Highway': 'highway',
    'Industrial': 'industrial',
    'Pasture': 'pasture',
    'PermanentCrop': 'permanent crop',
    'Residential': 'residential',
    'River': 'river',
    'SeaLake': 'sea lake',
}
# Build prompts from the dataset's original class names, not from the `label_to_class`
# global. This cell overwrites that global with prompts, so reading it back would
# raise KeyError on a re-run.
class_names = list(test_dataset1.label_to_class)
label_to_class = [f"A photo of a {map_clean_text[k]}" for k in class_names]
# label_to_class = [f"A centered satellite photo of a {map_clean_text[k]}" for k in class_names]
print(label_to_class)
textsenc = tokenizer(label_to_class, padding=True, return_tensors="pt").to(device)
# Inference only: avoid building (and holding) an autograd graph.
with torch.no_grad():
    text_features = F.normalize(textmodel(**textsenc).text_embeds, dim=-1)
text_features.shape

['A photo of a annual crop', 'A photo of a forest', 'A photo of a herbaceous vegetation', 'A photo of a highway', 'A photo of a industrial', 'A photo of a pasture', 'A photo of a permanent crop', 'A photo of a residential', 'A photo of a river', 'A photo of a sea lake']


torch.Size([10, 512])

In [9]:
# Zero-shot evaluation: score every test image against one text embedding per class.
model.eval()
test_batch = len(test_loader)
texts = label_to_class
# The class prompts are the same for every batch, so embed them once before the loop.
with torch.no_grad():
    textsenc = tokenizer(texts, padding=True, return_tensors="pt").to(device)  # tokenize
    class_embeddings = F.normalize(textmodel(**textsenc).text_embeds, dim=-1)  # embed with text encoder
class_embeddings_np = class_embeddings.cpu().numpy()

num_correct, num_total = 0, 0
# Iterate the DataLoader directly instead of the shared `test_dataset` iterator,
# so re-running this cell starts a fresh pass instead of raising StopIteration.
for imgs, labels in tqdm(test_loader, total=test_batch):
    with torch.no_grad():
        tr_image = transform(imgs).to(device)
        image_feature = model.forward_features(tr_image)
    # Both sides are L2-normalised, so this is cosine similarity.
    classlogits = image_feature.cpu().numpy() @ class_embeddings_np.T

    preds = np.argmax(classlogits, axis=1)
    labels = labels.detach().numpy()
    num_correct += int(np.sum(preds == labels))
    num_total += len(labels)
accuracy = num_correct / num_total if num_total else float("nan")
print(f"Accuracy: {accuracy}")
print(f"num_total: {num_total}")
print(f"num_total: {num_total}")

100%|██████████| 336/336 [00:45<00:00,  7.32it/s]

Accuracy: 0.27642124883504193
num_total: 5365





In [10]:
def zero_shot_classification(image, classes=None):
    """Score one image against text class labels and plot the image next to the scores.

    Args:
        image: a PIL image, an HxWxC numpy array accepted by `transforms.ToTensor`, or a
            float tensor shaped (3, H, W). Tensors are presumably already in [0, 1]
            (TODO confirm for your callers).
        classes: list of class strings to score; defaults to a small fixed set.

    Returns:
        numpy array of shape (1, len(classes)) with image/text similarity scores.
    """
    # The notebook's global `transform` has no ToTensor step, so convert non-tensor inputs here.
    if isinstance(image, torch.Tensor):
        image_tensor = image
        # matplotlib expects channels-last for display.
        display_image = image.detach().cpu().permute(1, 2, 0).numpy()
    else:
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        image_tensor = transforms.ToTensor()(image)
        display_image = image

    if classes is not None:
        texts = classes
    else:
        texts = ["tennis courts", "parking lot", "farmland", "lake", "park", "powerlines", "University Campus", "Beach", "Freeway"]

    with torch.no_grad():
        tr_image = transform(image_tensor).unsqueeze(0).to(device)
        image_feature = model.forward_features(tr_image)
        textsenc = tokenizer(texts, padding=True, return_tensors="pt").to(device)  # tokenize
        class_embeddings = F.normalize(textmodel(**textsenc).text_embeds, dim=-1)  # embed with text encoder
    classlogits = image_feature.cpu().numpy() @ class_embeddings.cpu().numpy().T
    scores = classlogits[0]

    fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(10, 5))
    ax_img.set_title("Input Image")
    ax_img.imshow(display_image)
    ax_img.axis('off')
    ax_bar.set_ylabel('Class matching score')
    ax_bar.set_xlabel('Classes')
    ax_bar.set_title("Graft best prediction: '{}'".format(texts[np.argmax(scores)]))
    ax_bar.bar(range(len(scores)), scores)
    ax_bar.set_xticks(range(len(texts)))
    ax_bar.set_xticklabels(texts, rotation=90)
    plt.show()
    return classlogits