## Setup

In [None]:
import torchvision.transforms as tfs
from src.models.utils import download_checkpoint, load_model
from src.experiments import *
from src.experiments.utils import *
from src.visualization import *
from src.optimization import *
from IPython.display import Image 
from PIL import Image
from src.optimization.optimizers import NormalizedOptimizer

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Side length (in pixels) that images are resized to further down.
img_size = 224

# Fetch the ProtoPNet checkpoint (the second argument is presumably the local
# file name -- confirm in download_checkpoint) and load the model onto `device`.
checkpoint_path = download_checkpoint(
    "https://drive.google.com/file/d/19m_SaRNEF7JXHjeyNu26AxgaEQXqpI00",
    "protopnet.pt",
)
model = load_model("protopnet", checkpoint_path, device=device)

In [None]:
# Prototype(s) to optimise for. Each entry looks like a
# (class index, prototype index) pair -- TODO confirm in visualize_prototypes.
ptypes = [(188, 1)]

# Single source of truth for the per-channel normalisation statistics, so the
# forward and inverse transforms below cannot drift apart.
norm_mean = (0.485, 0.456, 0.406)
norm_std = (0.229, 0.224, 0.225)

# Shape of the noise tensors used as optimisation starting points; tied to
# img_size so it always matches the resolution produced by `transform`.
size = (3, img_size, img_size)

normalize = tfs.Normalize(mean=norm_mean, std=norm_std)

# Resize to img_size x img_size, convert to a tensor, then normalise.
transform = tfs.Compose([
    tfs.Resize(size=(img_size, img_size)),
    tfs.ToTensor(),
    normalize,
])

# Undo `normalize` in two steps: first multiply by std (i.e. divide by 1/std),
# then add the mean back (i.e. subtract -mean).
invTrans = tfs.Compose([
    tfs.Normalize(mean=[0., 0., 0.], std=[1 / s for s in norm_std]),
    tfs.Normalize(mean=[-m for m in norm_mean], std=[1., 1., 1.]),
])

def before_optim_step(t):
    """Blur the tensor being optimised, in place.

    Passed as the ``before_optim_step`` callback to ``visualize_prototypes``
    (presumably invoked before every optimiser step, per the keyword name).
    Replaces ``t.data`` with a Gaussian-blurred version of ``t``
    (kernel size 7, sigma 2).

    Args:
        t: Tensor being optimised; updated in place through ``t.data``.
    """
    # NOTE(review): no clamping happens here. A clamp to [0, 1] used to be
    # computed into an unused local and never written back to ``t``, so it had
    # no effect and has been removed. If clamping is wanted, confirm the valid
    # range first: tensors produced by `transform` are mean/std-normalised and
    # therefore not confined to [0, 1].
    blur = tfs.GaussianBlur(kernel_size=7, sigma=2)
    t.data = blur(t).data

## Testing aggregation methods
Mean of similarities with exponent=1 seems to work best. The red head also seems to appear in places other than the bird's original head. Overall, the differences between the settings are not very significant.

In [None]:
# Load one cropped training image of the Red-bellied Woodpecker class and
# show it inline.
input_image = Image.open(
    "data/train_cropped/189.Red_bellied_Woodpecker/"
    "Red_Bellied_Woodpecker_0002_180879.jpg"
)
display(input_image)

In [None]:
# Compare every combination of metric, aggregation function and exponent,
# always starting from the same real image.
input_tensor = transform(input_image)
for metric in ["distance", "similarity"]:
    for agg_fn in ["mean", "mean_log"]:
        for exponent in [1., 2.]:
            print(f"metric: {metric} | agg_fn: {agg_fn} | exponent: {exponent}")
            loss_agg_fn = AggregationFn(metric=metric, agg_fn=agg_fn, exponent=exponent).to(device)
            # Give each run its own copy of the starting image. The
            # before_optim_step callback rewrites the optimised tensor's
            # .data, so if visualize_prototypes optimises its input directly
            # (NOTE(review): not visible from here -- confirm), sharing one
            # tensor would make every run start from the previous result.
            image = visualize_prototypes(
                model,
                ptypes,
                loss_agg_fn=loss_agg_fn,
                optimization_steps=100,
                input_tensor=input_tensor.clone(),
                before_optim_step=before_optim_step,
                optimizer_kwargs={'lr': 0.2},
            )
            # De-normalise before converting to a displayable PIL image.
            image = invTrans(image)
            pilimg = tfs.ToPILImage()(image)
            display(pilimg)


In [None]:
# Loss aggregation shared by the visualisation helpers defined further down:
# similarity metric, AggregationFn's defaults for everything else.
loss_agg_fn = AggregationFn(metric="similarity")
loss_agg_fn = loss_agg_fn.to(device)

In [None]:
def vis_box_get_names(bird, max_images=8, num_prototypes=10):
    """Show real-prototype visualisations for a few images of one class.

    Calls ``crop_images(bird)``, then goes through the ``.jpg`` files in
    ``data/train_cropped/<bird>`` in sorted order. For each of the first
    ``max_images`` files it displays the result of
    ``visualize_real_prototype`` for prototype indices
    ``0 .. num_prototypes - 1``.

    Args:
        bird: Class folder name of the form ``"<number>.<name>"``. The class
            index handed to ``visualize_real_prototype`` is ``<number> - 1``.
        max_images: Maximum number of images to process (default 8).
        num_prototypes: Number of prototype indices shown per image
            (default 10).

    Returns:
        List of ``"<bird>/<filename>"`` strings for the processed images.

    Raises:
        ValueError: If the part of ``bird`` before the first ``"."`` is not
            an integer.
    """
    crop_images(bird)
    directory = f"data/train_cropped/{bird}"
    p_num = int(bird.split(".")[0]) - 1

    # os.listdir returns entries in arbitrary order; sorting makes the chosen
    # images the same on every run and every machine.
    jpg_files = sorted(
        entry for entry in os.listdir(directory) if entry.endswith(".jpg")
    )

    names = []
    for count, filename in enumerate(jpg_files[:max_images]):
        name = f"{bird}/{filename}"
        names.append(name)
        for i in range(num_prototypes):
            print(f"Bird {count} | Prototype {i}")
            img = visualize_real_prototype(model, name, p_num, i)
            display(img)
    return names

In [None]:
def vis_noise(ptypes):
    """Optimise a random-noise input towards ``ptypes`` and display the result.

    Draws a standard-normal tensor of shape ``size``, runs
    ``visualize_prototypes`` on it for 200 steps with the module-level
    ``loss_agg_fn`` and ``before_optim_step``, then de-normalises the output
    with ``invTrans`` and displays it as a PIL image.

    Args:
        ptypes: Prototype specification forwarded to ``visualize_prototypes``.
    """
    noise = torch.randn(size)
    optimised = visualize_prototypes(
        model,
        ptypes,
        loss_agg_fn=loss_agg_fn,
        optimization_steps=200,
        input_tensor=noise,
        before_optim_step=before_optim_step,
        optimizer_kwargs={'lr': 0.2},
        print_interval=1000,
        display_interval=500,
    )
    restored = invTrans(optimised)
    display(tfs.ToPILImage()(restored))

def vis_imgs(ptypes, names):
    """Optimise real images towards ``ptypes`` and display before/after.

    For every entry in ``names`` the image ``data/train_cropped/<name>`` is
    displayed, transformed with ``transform``, optimised for 200 steps with
    ``visualize_prototypes`` (module-level ``loss_agg_fn`` and
    ``before_optim_step``), de-normalised with ``invTrans`` and displayed
    again as a PIL image.

    Args:
        ptypes: Prototype specification forwarded to ``visualize_prototypes``.
        names: Iterable of image paths relative to ``data/train_cropped``.
    """
    for name in names:
        # Close the file handle as soon as the pixel data is loaded.
        # convert() returns an independent, fully loaded copy, and forcing RGB
        # keeps the 3-channel `normalize` step valid for any grayscale or
        # RGBA files that may be in the dataset.
        with Image.open(f"data/train_cropped/{name}") as raw_image:
            input_image = raw_image.convert("RGB")
        display(input_image)

        input_tensor = transform(input_image)
        image = visualize_prototypes(
            model,
            ptypes,
            loss_agg_fn=loss_agg_fn,
            optimization_steps=200,
            input_tensor=input_tensor,
            before_optim_step=before_optim_step,
            optimizer_kwargs={'lr': 0.2},
            print_interval=1000,
            display_interval=500,
        )
        # De-normalise before converting to a displayable PIL image.
        image = invTrans(image)
        pilimg = tfs.ToPILImage()(image)
        display(pilimg)

# Red-bellied Woodpecker
## Only two kinds of prototypes

In [None]:
# Class folder to inspect (a sub-directory of data/train_cropped/).
bird = "189.Red_bellied_Woodpecker"
# Displays the real prototypes for a handful of images of this class and
# returns their relative paths for the optimisation cells that follow.
names = vis_box_get_names(bird)

## Red head

In [None]:
# 188 is the class index for folder 189 (folder number minus one, as computed
# in vis_box_get_names); the second entry is presumably the prototype index.
ptypes = [(188,1)] # red head
# Optimise towards this prototype first from random noise, then from each of
# the real images listed in `names`.
vis_noise(ptypes)
vis_imgs(ptypes, names)

## Black-white feathers

In [None]:
# Same class as the red-head prototype (index 188), different prototype.
ptypes = [(188,0)] # black-white feathers
# Optimise towards this prototype first from random noise, then from each of
# the real images listed in `names`.
vis_noise(ptypes)
vis_imgs(ptypes, names)

# Yellow-breasted Chat
## One yellowish prototype

In [None]:
# Class folder to inspect (a sub-directory of data/train_cropped/).
bird = "020.Yellow_breasted_Chat"
# Displays the real prototypes for a handful of images of this class and
# returns their relative paths for the optimisation cell that follows.
names = vis_box_get_names(bird)

In [None]:
# 19 is the class index for folder 020 (folder number minus one, as computed
# in vis_box_get_names); the second entry is presumably the prototype index.
ptypes = [(19,0)]
# Optimise towards this prototype first from random noise, then from each of
# the real images listed in `names`.
vis_noise(ptypes)
vis_imgs(ptypes, names)

# Florida Jay
## Not-so-great bluish prototypes

In [None]:
# Class folder to inspect (a sub-directory of data/train_cropped/).
bird = "074.Florida_Jay"
# Displays the real prototypes for a handful of images of this class and
# returns their relative paths for the optimisation cell that follows.
names = vis_box_get_names(bird)

In [None]:
# 73 is the class index for folder 074 (folder number minus one, as computed
# in vis_box_get_names); the second entry is presumably the prototype index.
ptypes = [(73,2)]
# Optimise towards this prototype first from random noise, then from each of
# the real images listed in `names`.
vis_noise(ptypes)
vis_imgs(ptypes, names)