In [None]:
import sys
import os
# Make the vendored `open_flamingo` package and the parent directory importable.
# Each entry is appended only if missing, so re-running this cell does not keep
# growing sys.path with duplicates.
directory_path = os.path.abspath('..')
for _extra_path in ("open_flamingo", directory_path):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
from datasets import COCOFlickrDataset, ImageNetDataset

import os
import shutil
import time
import string
import random

import numpy as np
import open_clip
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from training.scheduler import cosine_lr
from torchvision import transforms
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
from train.pgd_train import pgd
from train.apgd_train import apgd_train as apgd
import wandb
from utils import init_wandb, AverageMeter
from sam_data import SamData
from open_flamingo.eval.models.utils import unwrap_model
from train.utils import str2bool
from CLIP_eval.eval_utils import load_clip_model

import argparse

In [None]:
# Seed every RNG this notebook touches: torch (also used by DataLoader shuffling),
# NumPy, and Python's `random` module (imported in the first cell).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
import wandb  # NOTE(review): duplicate of the import in the first cell; harmless but redundant

# Authenticate with Weights & Biases; the login result is shown as the cell output.
wandb_logged_in = wandb.login()
wandb_logged_in

In [None]:
# Experiment configuration: every knob the rest of the notebook reads.
main_device = 'cuda:0'
eps = 2/255                  # perturbation budget passed to the attack
batch_size = 64
# Root of the ImageNet CLS-LOC folder; `train/` and `val/` subfolders are read below.
# Set the IMAGENET_DATA_PATH environment variable to override without editing code.
data_path = os.environ.get(
    'IMAGENET_DATA_PATH',
    "C:/CodesSpring24/Data/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC",
)
inner_loss = 'l2'            # loss used by the attack objective: 'l2' or 'ce'
norm = 'Linf'                # threat-model norm passed to the attack
iterations_adv = 10          # number of attack iterations
# Step size for step-based attacks such as PGD. Assigned once: a former duplicate
# assignment (2/255) was always overwritten by this value. APGD below does not use it.
stepsize_adv = 1.
clean_weight = 0.            # not used in the cells shown; presumably for the training loss
momentum_sgd = 0.9           # not used in the cells shown; presumably for the optimizer
template = 'std'             # prompt template name: 'std' or 'blurry'
clip_model_name = 'ViT-L-14'
output_normalize = False     # whether ClipVisionModel L2-normalizes its output embeddings
attack = 'apgd'


In [None]:

# Load two copies of the same pretrained CLIP: `model_orig` serves as the frozen
# reference (its eval transform is also kept), `model` is the copy being trained.
pretrained_tag = 'openai'
model_orig, _, image_processor = open_clip.create_model_and_transforms(clip_model_name, pretrained=pretrained_tag)
model, _, _ = load_clip_model(clip_model_name, pretrained_tag)


In [None]:
class ClipVisionModel(torch.nn.Module):
    """Wrap a vision tower so it consumes un-normalized images.

    The input normalization (presumably torchvision's ``Normalize``) is applied
    inside ``forward``. Attacks can therefore perturb the raw image tensor directly.

    Args:
        model: vision tower; called on the normalized input.
        args: opaque value stored as ``self.args`` (unused by this class).
        normalize: callable applied to the input before ``model``.
    """

    def __init__(self, model, args, normalize):
        super().__init__()
        self.args = args
        self.normalize = normalize
        self.model = model

    def forward(self, vision, output_normalize):
        """Embed ``vision``; L2-normalize along the last dim if ``output_normalize``."""
        features = self.model(self.normalize(vision))
        return F.normalize(features, dim=-1) if output_normalize else features

In [None]:
# Split open_clip's eval transform into everything-but-normalization (used by the
# datasets) and the final normalization step (applied inside ClipVisionModel.forward).
# The data loaders therefore yield un-normalized images for the attack to perturb.
preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
normalize = image_processor.transforms[-1]
# The split relies on Normalize being the last step; fail early if that ever changes.
assert isinstance(normalize, transforms.Normalize), \
    f'expected Normalize as the last transform, got {normalize!r}'
# `image_processor` is intentionally not deleted: deleting it made this cell fail
# with NameError when re-run.
print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')
print(f'[normalize] {normalize}')

In [None]:
def _imagenet_split(split):
    """Build the un-normalized ImageNet dataset rooted at ``<data_path>/<split>``."""
    return ImageNetDataset(
        root=f"{data_path}/{split}",
        transform=preprocessor_without_normalize,
    )


dataset = _imagenet_split('train')
dataset_eval = _imagenet_split('val')

In [None]:
# Training loader: shuffled; drop_last keeps every batch at full size.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)
# NOTE(review): the eval loader is also shuffled and drops the last partial batch.
# That looks intended for evaluating on a random subset of batches. Confirm before
# reporting numbers as full-validation-set results.
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)

# Resolve the template name to a prompt format string. `template` is overwritten with
# the format string, so an already-resolved value is accepted too. That keeps this
# cell re-runnable instead of raising "Unknown template" on the second run.
_TEMPLATES = {
    'std': 'This is a photo of a {}',
    'blurry': 'This is a blurry photo of a {}',
}
if template in _TEMPLATES:
    template = _TEMPLATES[template]
elif template not in _TEMPLATES.values():
    raise ValueError(f'Unknown template: {template}')

print(f'template: {template}')
# One prompt per ImageNet-1k class, in the mapping's iteration order
# (presumably class-id order, so prompt index == class id).
texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
print("These are samples",texts[:10])
text_tokens = open_clip.tokenize(texts)

# Trailing ';' suppresses the (very long) module repr that .to() would otherwise display.
model_orig.to(main_device);


In [None]:
with torch.no_grad():
    # Encode every class prompt with the original model's text tower. The tokens are
    # split into two batches because encoding them all at once runs out of memory.
    # encode_text is called on the open_clip model directly (not through a wrapper),
    # so normalize=True always L2-normalizes each text embedding.
    token_batches = (text_tokens[:500], text_tokens[500:])
    embedding_text_labels_norm = torch.cat(
        [model_orig.encode_text(chunk.to(main_device), normalize=True).detach().cpu() for chunk in token_batches]
    ).T.to(main_device)

    # After the transpose each column is one class embedding, so re-normalizing
    # along dim 0 must be a no-op.
    assert torch.allclose(
        F.normalize(embedding_text_labels_norm, dim=0),
        embedding_text_labels_norm
    )

    # Expected (embedding_dim, num_classes) per supported backbone.
    expected_dim_by_model = {'ViT-B-32': 512, 'ViT-L-14': 768, 'ViT-L-14-336': 768}
    if clip_model_name not in expected_dim_by_model:
        raise ValueError(f'Unknown model: {clip_model_name}')
    expected_shape = (expected_dim_by_model[clip_model_name], 1000)
    assert embedding_text_labels_norm.shape == expected_shape, embedding_text_labels_norm.shape
# ---------------------------------------------------------------------------
# Wrap the vision towers so they take un-normalized images.
# model_orig is moved back to the CPU here; the attack cell moves it to main_device again.
# NOTE(review): this rebinds model_orig / model to the wrappers. Re-running this cell,
# or the text-embedding cell before it, without reloading the models will fail.
model_orig.cpu()
# args="args" is a placeholder; ClipVisionModel only stores it.
model_orig = ClipVisionModel(model=model_orig.visual, args="args", normalize=normalize)

model = ClipVisionModel(model=model.visual, args="args", normalize=normalize)

# Parameters for the optimizer (all params have requires_grad=True).
# .parameters() returns a one-shot generator; materialize it so a second consumer
# (e.g. re-creating the optimizer) does not silently get an empty parameter list.
params = list(unwrap_model(model).model.parameters())

In [None]:
def l2(out, targets, reduction='none'):
    """Squared L2 distance between corresponding rows of ``out`` and ``targets``.

    Args:
        out: tensor with a batch dimension first (batch size must be > 1).
        targets: tensor with the same shape as ``out``.
        reduction: ``'mean'`` averages the per-sample distances to a scalar. Any
            other value returns the per-sample vector of shape ``(batch,)``.

    Raises:
        AssertionError: on shape mismatch, a batch of size <= 1, or (for
            non-'mean' reductions) a result that is not one value per sample.
    """
    assert out.shape == targets.shape, f'{out.shape} != {targets.shape}'
    assert out.shape[0] > 1
    # Element-wise squared error, summed over dim 1 -> one distance per sample.
    per_sample = F.mse_loss(out, targets, reduction='none').sum(dim=1)
    if reduction == 'mean':
        return torch.mean(per_sample)
    assert per_sample.shape == (out.shape[0],), f'{per_sample.shape} != {(out.shape[0],)}'
    return per_sample

def ce(out, targets, reduction='mean'):
    """Cross-entropy between logits ``out`` and class-index ``targets``.

    Thin wrapper over ``F.cross_entropy`` that first checks that the batch sizes
    agree and that the batch holds more than one sample.
    """
    batch = out.shape[0]
    assert batch == targets.shape[0], (out.shape, targets.shape)
    assert batch > 1
    return F.cross_entropy(out, targets, reduction=reduction)

In [None]:
def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
                 embedding_text_labels_norm=None, reduction='mean'):
    """Dispatch to the loss named by ``loss_str``.

    - ``'ce'``: builds logits ``embedding @ (logit_scale * embedding_text_labels_norm)``
      and applies cross-entropy against ``targets``.
    - ``'l2'``: squared L2 distance between ``embedding`` and ``embedding_orig``
      (``targets`` is ignored).

    Raises:
        ValueError: if ``loss_str`` names neither loss.
    """
    if loss_str == 'ce':
        logits = embedding @ (logit_scale * embedding_text_labels_norm)
        return ce(out=logits, targets=targets, reduction=reduction)
    if loss_str == 'l2':
        return l2(out=embedding, targets=embedding_orig, reduction=reduction)
    raise ValueError(f'loss {loss_str} not supported')

In [None]:
class ComputeLossWrapper:
    """Callable that pre-binds every ``compute_loss`` argument except the
    embedding and targets, so an attack can call ``loss_fn(embedding, targets)``.
    """

    def __init__(self, embedding_orig, embedding_text_labels_norm, reduction='mean', loss=None,
                 logit_scale=100.):
        self.loss_str = loss  # loss name forwarded to compute_loss ('l2' or 'ce')
        self.reduction = reduction
        self.logit_scale = logit_scale
        self.embedding_orig = embedding_orig
        self.embedding_text_labels_norm = embedding_text_labels_norm

    def __call__(self, embedding, targets):
        bound = dict(
            loss_str=self.loss_str,
            embedding_orig=self.embedding_orig,
            logit_scale=self.logit_scale,
            embedding_text_labels_norm=self.embedding_text_labels_norm,
            reduction=self.reduction,
        )
        return compute_loss(embedding=embedding, targets=targets, **bound)

In [None]:
model_orig.to(main_device)
model.to(main_device)
model_orig.eval()
model.train()

# Only APGD is wired up in this cell. Fail loudly instead of silently running APGD
# (with a mismatched 'mean' loss reduction) when the config asks for another attack.
if attack != 'apgd':
    raise ValueError(f'attack {attack!r} is not supported in this cell (only "apgd")')

for i, (data, targets) in enumerate(dataloader):
    data = data.to(main_device)
    targets = targets.to(main_device)

    # Reference embeddings from the frozen original model; no gradients needed.
    with torch.no_grad():
        embedding_orig = model_orig(vision=data, output_normalize=output_normalize)

    # The attack runs with the model in eval mode.
    # NOTE(review): nothing switches it back to train mode afterwards.
    model.eval()

    # Loss the attack optimizes; APGD needs per-sample values, hence reduction='none'.
    loss_inner_wrapper = ComputeLossWrapper(
        embedding_orig, embedding_text_labels_norm,
        reduction='none', loss=inner_loss,
        logit_scale=100.
    )

    data_adv = apgd(
        model=model,
        loss_fn=loss_inner_wrapper,
        x=data,
        y=targets,
        norm=norm,
        eps=eps,
        n_iter=iterations_adv,
        verbose=True
    )
    del loss_inner_wrapper

    # Summarize the batch instead of dumping the whole adversarial tensor every step.
    max_delta = (data_adv - data).abs().amax().item()
    print(f'[batch {i}] data_adv size: {tuple(data_adv.size())}, max |data_adv - data|: {max_delta:.6f}')