## Setup

In [None]:
import torchvision.transforms as tfs
from src.models.utils import download_checkpoint, load_model
from src.experiments import *
from src.experiments.utils import *
from src.visualization import *
from src.optimization import *
from IPython.display import Image 
from PIL import Image

# Square image side length (pixels) used when resizing inputs in this notebook.
img_size = 224

# Fetch the ProtoPNet checkpoint and load it on the CPU.
# NOTE(review): presumably the file is stored locally as "protopnet.pt" — confirm
# download_checkpoint's caching/overwrite behavior.
checkpoint_path = download_checkpoint(
    "https://drive.google.com/file/d/19m_SaRNEF7JXHjeyNu26AxgaEQXqpI00",
    "protopnet.pt",
)
model = load_model('protopnet', checkpoint_path, device="cpu")

In [None]:
# Light randomized augmentations handed to the visualization helpers via `transforms=`.
# The tensor goes through PIL and back, so its values are expected to lie in [0, 1]
# (ToPILImage scales float tensors by 255 and casts to uint8).
# NOTE(review): the ToPILImage -> ToTensor round-trip quantizes to uint8 and does not
# propagate gradients; if the helper applies this inside the autograd graph, consider
# tensor-native transforms instead — confirm against visualize_prototypes.
transforms_base = tfs.Compose([
    tfs.ToPILImage(),
    # All factors are 0, so this is currently a no-op; kept as a tuning knob.
    tfs.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
    # Rotate by a random angle in [-1, 1] degrees.
    tfs.RandomRotation((-1, 1)),
    # Crop 99-100% of the area and resize back to the notebook's input resolution.
    tfs.RandomResizedCrop(size=img_size, scale=(0.99, 1)),
    tfs.ToTensor(),
])

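## Image optimization

Start the prototype visualization from a real training image. The first run passes the random `transforms_base` augmentations; the second runs the same optimization without them.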

In [None]:
# Prototype(s) to visualize.
# NOTE(review): presumably (class index, prototype index); 188 matches the 0-based
# index of folder "189.Red_bellied_Woodpecker" — confirm.
ptypes = [(188, 1)]

# Per-channel normalization statistics (the standard ImageNet values).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
normalize = tfs.Normalize(mean, std)
transform = tfs.Compose([
    tfs.Resize(size=(img_size, img_size)),
    tfs.ToTensor(),
    normalize])

# Open via a context manager so the file handle is released, and force RGB so
# grayscale/CMYK images still match the 3-channel Normalize above.
with Image.open("data/train_cropped/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0002_180879.jpg") as raw_image:
    input_image = raw_image.convert("RGB")
input_tensor = transform(input_image)

# Preview the resized image *before* normalization: ToPILImage expects values in
# [0, 1], and the normalized tensor would wrap around when cast to uint8.
display(tfs.Resize(size=(img_size, img_size))(input_image))

# Optimize starting from the real image, passing the random augmentations.
image = visualize_prototypes(model, ptypes, optimization_steps=100, input_tensor=input_tensor, transforms=transforms_base)

pilimg = tfs.ToPILImage()(image)
display(pilimg)

In [None]:
# Open via a context manager so the file handle is released, and force RGB so
# grayscale/CMYK images still match the 3-channel normalization in `transform`.
with Image.open("data/train_cropped/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0002_180879.jpg") as raw_image:
    input_image = raw_image.convert("RGB")
input_tensor = transform(input_image)

# Preview the resized image *before* normalization: ToPILImage expects values in
# [0, 1], and the normalized tensor would wrap around when cast to uint8.
display(tfs.Resize(size=(img_size, img_size))(input_image))

# Same optimization as with augmentations, but without any `transforms`.
image = visualize_prototypes(model, ptypes, optimization_steps=100, input_tensor=input_tensor)

pilimg = tfs.ToPILImage()(image)
display(pilimg)

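## Noise optimization

Same as above, but start from random Gaussian noise and use a `before_optim_step` hook that clamps the image to [0, 1] and blurs it — again with and without `transforms_base`.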

In [None]:
ptypes = [(188, 1)]

# Build the blur once and reuse it on every call instead of per step.
blur_transform = tfs.GaussianBlur(7, 2)  # kernel size 7, sigma 2


def before_optim_step(t):
    """Clamp ``t`` to [0, 1] and Gaussian-blur it, replacing ``t``'s data in place.

    Runs under ``no_grad`` so this projection step is not recorded by autograd.
    """
    with torch.no_grad():
        t.data = blur_transform(torch.clamp(t, 0, 1))


# Fixed seed so the starting noise (and torch-driven random augmentations) are reproducible.
torch.manual_seed(0)
size = (3, img_size, img_size)
# Standard-normal noise; before_optim_step clamps it into [0, 1].
input_tensor = torch.randn(size)

image = visualize_prototypes(model, ptypes, optimization_steps=100, input_tensor=input_tensor,
                             before_optim_step=before_optim_step, optimizer_kwargs={'lr': 0.2}, transforms=transforms_base)

pilimg = tfs.ToPILImage()(image)
display(pilimg)

In [None]:
ptypes = [(188, 1)]

# Build the blur once and reuse it on every call instead of per step.
blur_transform = tfs.GaussianBlur(7, 2)  # kernel size 7, sigma 2


def before_optim_step(t):
    """Clamp ``t`` to [0, 1] and Gaussian-blur it, replacing ``t``'s data in place.

    Runs under ``no_grad`` so this projection step is not recorded by autograd.
    """
    with torch.no_grad():
        t.data = blur_transform(torch.clamp(t, 0, 1))


# Fixed seed so the starting noise is reproducible.
torch.manual_seed(0)
size = (3, img_size, img_size)
# Standard-normal noise; before_optim_step clamps it into [0, 1].
input_tensor = torch.randn(size)

# No `transforms` here: same noise optimization without augmentations.
image = visualize_prototypes(model, ptypes, optimization_steps=100, input_tensor=input_tensor,
                             before_optim_step=before_optim_step, optimizer_kwargs={'lr': 0.2})

pilimg = tfs.ToPILImage()(image)
display(pilimg)

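## Octaves

Run `visualize_prototypes_octaves` for 1000 steps from the same training image, with the same clamp-and-blur `before_optim_step` hook.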

In [None]:
ptypes = [(188, 1)]

# Per-channel normalization statistics (the standard ImageNet values).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
normalize = tfs.Normalize(mean, std)
transform = tfs.Compose([
    tfs.Resize(size=(img_size, img_size)),
    tfs.ToTensor(),
    normalize])

# Build the blur once and reuse it on every call instead of per step.
blur_transform = tfs.GaussianBlur(7, 2)  # kernel size 7, sigma 2


def before_optim_step(t):
    """Clamp ``t`` to [0, 1] and Gaussian-blur it, replacing ``t``'s data in place.

    Runs under ``no_grad`` so this projection step is not recorded by autograd.
    """
    # NOTE(review): input_tensor below is Normalize()d (values roughly in
    # [-2.12, 2.64]), so clamping to [0, 1] discards much of its range — confirm
    # which space visualize_prototypes_octaves optimizes in.
    with torch.no_grad():
        t.data = blur_transform(torch.clamp(t, 0, 1))


# Open via a context manager so the file handle is released, and force RGB so
# grayscale/CMYK images still match the 3-channel Normalize above.
with Image.open("data/train_cropped/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0002_180879.jpg") as raw_image:
    input_image = raw_image.convert("RGB")
input_tensor = transform(input_image)

# Preview the resized image *before* normalization: ToPILImage expects values in
# [0, 1], and the normalized tensor would wrap around when cast to uint8.
display(tfs.Resize(size=(img_size, img_size))(input_image))

image = visualize_prototypes_octaves(model, ptypes, optimization_steps=1000, input_tensor=input_tensor,
                                     before_optim_step=before_optim_step, optimizer_kwargs={'lr': 0.2})

pilimg = tfs.ToPILImage()(image)
display(pilimg)