In [1]:
import os

# The MPS fallback flag has to be in the environment before `import torch`:
# PyTorch reads it while the library is being loaded, so setting it afterwards
# leaves unsupported MPS ops raising instead of falling back to the CPU.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import Accelerator

# Accelerate is only used here to pick the device the notebook runs on.
accelerator = Accelerator()
device = accelerator.device


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import ViTImageProcessor, ViTForImageClassification, ViTConfig
from PIL import Image
import requests

# Download one COCO validation image as a quick sanity check of the classifier.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout keeps the cell from hanging on a stalled connection, and
# raise_for_status() reports HTTP errors here rather than as an obscure
# image-decoding failure further down.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Config, image processor and classifier all come from the same checkpoint.
pretrained_name = 'google/vit-base-patch16-224'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)

inputs = processor(images=image, return_tensors="pt")
inputs = inputs.to(device)
# Inference only: the weights are still trainable at this point, so without
# no_grad() a full autograd graph would be built and kept alive by `outputs`.
with torch.no_grad():
    outputs = pred_model(**inputs, output_hidden_states=True)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", pred_model.config.id2label[predicted_class_idx])

Predicted class: Egyptian cat


In [3]:
class MLP(nn.Module):
    """Residual MLP: input projection, a stack of bottleneck blocks, output projection.

    Each block maps ``hidden_dim -> bottleneck_dim -> bottleneck_dim -> hidden_dim``
    with a Dropout after every Linear and a ReLU after the first two Dropouts.
    The block output is added to its input (shortcut connection).

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of the residual stream.
        output_dim: size of the last dimension of the output.
        num_blocks: number of residual bottleneck blocks.
        bottleneck_dim: inner width of each block.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_blocks=5, bottleneck_dim=64):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        # The layer order inside each Sequential determines the state_dict keys
        # (layers.<block>.<index>), so it is kept exactly as before.
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, bottleneck_dim),
                nn.Dropout(),
                nn.ReLU(),
                nn.Linear(bottleneck_dim, bottleneck_dim),
                nn.Dropout(),
                nn.ReLU(),
                nn.Linear(bottleneck_dim, hidden_dim),
                nn.Dropout(),
            )
            for _ in range(num_blocks)
        ])
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply the input projection, every residual block, then the output projection."""
        hidden = self.input_layer(x)
        for block in self.layers:
            hidden = hidden + block(hidden)  # shortcut connection
        return self.output_layer(hidden)

In [4]:
# Display the patch size stored in the ViT config loaded earlier.
config.patch_size

16

In [5]:
def pairwise_cosine_similarity(Q, K):
    """Return the dot products of Q and K scaled by the squared norm of each key.

    Despite the name this is not a true cosine similarity: the scores are
    divided by ``||K||^2`` only (plus 1e-5 for numerical safety), not by
    ``||Q|| * ||K||``.

    Args:
        Q: query tensor whose last dimension is the feature dimension.
        K: key tensor with the same feature dimension as ``Q``.

    Returns:
        Tensor of ``Q @ K^T`` where every column is divided by the squared
        norm of the corresponding key.
    """
    key_sq_norms = K.pow(2).sum(-1).unsqueeze(-2)
    dot_products = torch.matmul(Q, K.transpose(-2, -1))
    return dot_products / (key_sq_norms + 1e-5)

class SimplifiedAttention(nn.Module):
    """Projection-free attention that weights values by key-normalised similarity.

    The ``query``/``key``/``value`` linear layers are created but not applied in
    ``forward``; queries, keys and values are used as given. The layers stay
    registered so parameter names and the state_dict layout are unchanged.
    """

    def __init__(self, embed_size):
        super().__init__()
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, q, k, v, reduce=True):
        """Attend from ``q`` over ``k`` and combine ``v``.

        Args:
            q: queries, either [N, d] (treated as a single query per item) or [N, P, d].
            k: keys, [N, L, d].
            v: values, [N, L, d].
            reduce: if True, sum the weighted values over L and return [N, P, d];
                otherwise return the per-key weighted values as [N, P, L, d].

        Returns:
            Tuple ``(attention_outputs, attention_weights)`` where the weights
            are [N, P, L].
        """
        assert q.dim() in (2, 3), "The query tensor must be 2 or 3 dimensional."
        # A 2-D query is a single prediction per item: add the P = 1 axis.
        queries = q.unsqueeze(1) if q.dim() == 2 else q

        # No softmax or other normalisation is applied to the weights.
        attention_weights = pairwise_cosine_similarity(queries, k)  # [N, P, L]

        if not reduce:
            # Keep each key's contribution separate: [N, P, L] x [N, L, d] -> [N, P, L, d]
            per_key_outputs = torch.einsum('bij,bjk->bijk', attention_weights, v)
            return per_key_outputs, attention_weights

        summed_outputs = torch.matmul(attention_weights, v)  # [N, P, d]
        return summed_outputs, attention_weights

# import collections
# class SurrogatePatchEmbeddings(nn.Module):
#     """
#     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into
#     patches of shape `(batch_size, seq_length, patch_height, patch_width)` to be consumed by a
#     surrogate model.
#     """

#     def __init__(self, config):
#         super().__init__()
#         image_size, patch_size = config.image_size, config.patch_size
#         num_channels, hidden_size = config.num_channels, config.hidden_size

#         image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
#         patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
#         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
#         self.image_size = image_size
#         self.patch_size = patch_size
#         self.num_channels = num_channels
#         self.num_patches = num_patches

#         self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

#     def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
#         batch_size, num_channels, height, width = pixel_values.shape
#         if num_channels != self.num_channels:
#             raise ValueError(
#                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
#                 f" Expected {self.num_channels} but got {num_channels}."
#             )
#         if not interpolate_pos_encoding:
#             if height != self.image_size[0] or width != self.image_size[1]:
#                 raise ValueError(
#                     f"Input image size ({height}*{width}) doesn't match model"
#                     f" ({self.image_size[0]}*{self.image_size[1]})."
#                 )
#         embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
#         return embeddings

class SurrogateInterpretation(nn.Module):
    """Surrogate explainer trained to mimic a frozen image classifier.

    The wrapped ``pred_model`` supplies the CLS hidden state, which is used as
    the attention query. Patch embeddings from ``input_embed`` go through a
    trainable residual ``MLP`` and serve as both keys and values. The attended
    output is fed to the frozen ``classifier_head`` and trained to reproduce
    the pseudo labels obtained by applying the same head to the CLS state.
    All parameters of ``pred_model``, ``classifier_head`` and ``input_embed``
    are frozen.
    """

    def __init__(self, pred_model, classifier_head, input_embed, hidden_size) -> None:
        """
        pred_model: prediction model
        classifier_head: last fully connected layer
        input_embed: module called as ``input_embed(pixel_values=...)`` that turns
            the input into a sequence of token tensors
        hidden_size: feature dimension of the token tensors / CLS hidden state
        """
        super().__init__()

        # Take the label count from the wrapped model rather than from a
        # notebook-level `config` global, so the class does not depend on
        # hidden kernel state. Fall back to the head's output width.
        pred_config = getattr(pred_model, "config", None)
        if pred_config is not None and hasattr(pred_config, "num_labels"):
            self.num_labels = pred_config.num_labels
        else:
            self.num_labels = getattr(classifier_head, "out_features", None)
        self.pred_model = pred_model

        # Input embedding, it doesn't necessarily have to be an embedding per se.
        # It is used to convert the input to a list of token tensors.
        self.input_embed = input_embed

        # Classifier head
        self.classifier = classifier_head
        # Transform function to non-linearly transform patch embedding to the representation space
        self.transform_func = MLP(input_dim=hidden_size,
                                  hidden_dim=hidden_size,
                                  output_dim=hidden_size,
                                  num_blocks=5,
                                  bottleneck_dim=64)
        self.attention = SimplifiedAttention(embed_size=hidden_size)

        # freeze parameters of the prediction model.
        self.freeze_params()

        # Only cls_loss_func and cossim_loss_func are used by compute_loss; the
        # other two are kept so existing attribute access keeps working.
        self.sim_loss_func = nn.MSELoss()
        self.cls_loss_func = nn.CrossEntropyLoss()
        self.kl_loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)
        self.cossim_loss_func = nn.CosineSimilarity(dim=-1)

    def freeze_params(self,):
        """Disable gradients for the prediction model, classifier head and input embedding."""
        for name, param in self.pred_model.named_parameters():
            param.requires_grad = False
        for param in self.classifier.parameters():
            param.requires_grad = False
        for param in self.input_embed.parameters():
            param.requires_grad = False
        return

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen teacher modules in eval mode.

        ``requires_grad = False`` does not turn off dropout or other
        train-time behaviour, so without this the pseudo labels could become
        stochastic whenever the surrogate is put in training mode.
        """
        super().train(mode)
        self.pred_model.eval()
        self.classifier.eval()
        self.input_embed.eval()
        return self

    def compute_loss(self, pred_out, pseudo_label_out, pred, pseudo_label):
        """Combine cross-entropy on pseudo labels with a negative cosine-similarity term.

        Args:
            pred_out: surrogate outputs compared by cosine similarity, [M, C].
            pseudo_label_out: teacher outputs compared against ``pred_out``, [M, C].
            pred: surrogate logits for the cross-entropy term, [M, C].
            pseudo_label: teacher class indices, [M].

        Returns:
            Dict with ``loss`` (= ``cls_loss + cos_sim``), ``cls_loss`` and
            ``cos_sim`` (the negated mean cosine similarity).
        """
        cls_loss = self.cls_loss_func(pred, pseudo_label)
        # Negated so that minimising the loss maximises the similarity.
        cos_sim = - self.cossim_loss_func(pred_out, pseudo_label_out).mean()

        loss = cls_loss + 1 * cos_sim

        return {'loss': loss,
                'cls_loss': cls_loss,
                'cos_sim': cos_sim}

    def forward(self, pixel_values, labels=None, reduce_attention=True):
        """Run teacher and surrogate, and return the teacher outputs plus surrogate extras.

        Args:
            pixel_values: image batch passed to ``pred_model`` and ``input_embed``.
            labels: optional ground-truth class indices; when given, the accuracy
                of the pseudo labels against them is stored under ``pred_acc``.
            reduce_attention: if False, ``attention_output`` in the result keeps
                the per-patch contributions ([N, P, L, d]); the loss is still
                computed on their sum over patches.

        Returns:
            The ``pred_model`` output object with the extra keys ``patch_reprs``,
            ``attention_output``, ``attention_weights``, ``last_hidden_state``,
            ``loss``, ``cossim_loss``, ``cls_loss``, ``acc`` and (with labels)
            ``pred_acc``.
        """
        outputs = self.pred_model(pixel_values=pixel_values, output_hidden_states=True)
        # NOTE(review): this is the raw last entry of `hidden_states`; if the wrapped
        # model applies a final normalisation before its classifier, the pseudo
        # labels below can differ from `outputs.logits` - confirm this is intended.
        last_cls_hidden_state = outputs['hidden_states'][-1][:, 0, :]  # [N, d] the last hidden state of the cls token
        patch_embeddings = self.input_embed(pixel_values=pixel_values)  # [N, L, d]

        patch_reprs = self.transform_func(patch_embeddings)
        attention_output, attention_weights = self.attention(
            last_cls_hidden_state,
            patch_reprs,
            patch_reprs,
            reduce=reduce_attention
        )  # attention_weights [N, P, L]; attention_output [N, P, d] (or [N, P, L, d] when not reduced)

        pseudo_label_out = self.classifier(last_cls_hidden_state)
        # Flatten with reshape(-1) instead of view(shape[-1]), which only
        # happened to work because the labels were already one-dimensional.
        pseudo_label = pseudo_label_out.argmax(-1).reshape(-1)
        pseudo_label_out = pseudo_label_out.contiguous().view(-1, pseudo_label_out.shape[-1])

        if reduce_attention:
            pooled_output = attention_output
        else:
            # Summing the per-patch contributions over L gives the same tensor the
            # reduced path produces, so the loss shapes match the pseudo labels
            # (previously this path failed on a batch-size mismatch in the loss).
            pooled_output = attention_output.sum(dim=-2)  # [N, P, L, d] -> [N, P, d]

        pred = self.classifier(pooled_output)  # [N, P, out]
        pred = pred.contiguous().view(-1, pred.shape[-1])

        loss_dict = self.compute_loss(pred, pseudo_label_out, pred, pseudo_label)
        loss = loss_dict['loss']

        # Agreement between surrogate predictions and the teacher's pseudo labels.
        pred_labels = pred.argmax(-1).view(-1)
        correct = (pred_labels == pseudo_label).sum()
        accuracy = correct / len(pred_labels)

        if labels is not None:
            # Accuracy of the teacher's pseudo labels against the ground truth.
            pred_accuracy = (pseudo_label == labels).sum() / len(labels)
            outputs['pred_acc'] = pred_accuracy

        outputs['patch_reprs'] = patch_reprs
        outputs['attention_output'] = attention_output
        outputs['attention_weights'] = attention_weights
        outputs['last_hidden_state'] = last_cls_hidden_state
        outputs['loss'] = loss
        outputs['cossim_loss'] = loss_dict['cos_sim']
        outputs['cls_loss'] = loss_dict['cls_loss']
        outputs['acc'] = accuracy

        return outputs

In [6]:
from torchvision.transforms import (CenterCrop, Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, Resize, ToTensor)

# Reuse the preprocessing statistics and target size of the pretrained processor.
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

# Built once and shared by both pipelines (it was previously created but never
# used, while each pipeline instantiated its own identical Normalize).
normalize = Normalize(mean=image_mean, std=image_std)

# Training: random crop + horizontal flip as augmentation.
_train_transforms = Compose(
    [
        RandomResizedCrop(size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Validation: deterministic resize + centre crop.
_val_transforms = Compose(
    [
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ]
)

def train_transforms(examples):
    """Add a ``pixel_values`` entry with the training-augmented tensor of each image.

    Every image in ``examples['image']`` is converted to RGB and passed through
    ``_train_transforms``; the batch dict is updated in place and returned.
    """
    augmented = []
    for image in examples['image']:
        augmented.append(_train_transforms(image.convert('RGB')))
    examples['pixel_values'] = augmented
    return examples

def val_transforms(examples):
    """Add a ``pixel_values`` entry with the validation-preprocessed tensor of each image.

    Every image in ``examples['image']`` is converted to RGB and passed through
    ``_val_transforms``; the batch dict is updated in place and returned.
    """
    preprocessed = []
    for image in examples['image']:
        preprocessed.append(_val_transforms(image.convert('RGB')))
    examples['pixel_values'] = preprocessed
    return examples

In [7]:
from datasets import load_dataset
dataset = load_dataset("mrm8488/ImageNet1K-val")
# Keep only the 'train' split; from here on `dataset` is a single Dataset.
dataset = dataset['train']
# Fixed seed so the 90/10 train/validation split is identical on every
# Restart & Run All (the split was previously re-drawn at random each run).
splits = dataset.train_test_split(test_size=0.1, seed=42)
train_ds = splits['train']
val_ds = splits['test']
# Transforms are applied lazily when examples are accessed.
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)

Repo card metadata block was not found. Setting CardData to empty.


In [8]:
from torch.utils.data import DataLoader

def collate_fn(examples):
    """Batch a list of examples into stacked pixel tensors and a label tensor.

    Args:
        examples: list of dicts, each with a ``"pixel_values"`` tensor and a
            ``'label'`` value.

    Returns:
        Dict with ``"pixel_values"`` (examples stacked along a new first
        dimension) and ``"labels"`` (tensor built from the labels).
    """
    images = []
    targets = []
    for sample in examples:
        images.append(sample["pixel_values"])
        targets.append(sample['label'])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}


In [9]:
batch_size = 256
# Use the constant above rather than repeating the literal, so changing
# `batch_size` actually changes the loader.
train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

# Pull one batch to check the tensor shapes coming out of the loader.
batch = next(iter(train_dataloader))
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)

pixel_values torch.Size([256, 3, 224, 224])
labels torch.Size([256])


In [10]:
# input_embed = pred_model.get_input_embeddings()
# classifier_head = pred_model.classifier
# hidden_size = config.hidden_size
# model = SurrogateInterpretation(pred_model=pred_model, classifier_head=classifier_head, input_embed=input_embed, hidden_size=hidden_size)
# model.to(device)
# # outputs = model(**inputs)

In [11]:
# Build the surrogate around the pretrained classifier's own embedding layer and head.
input_embed = pred_model.get_input_embeddings()
classifier_head = pred_model.classifier
hidden_size = config.hidden_size
model = SurrogateInterpretation(pred_model=pred_model, classifier_head=classifier_head, input_embed=input_embed, hidden_size=hidden_size)
model.to(device)

# Smoke test on the single sample image before training.
outputs = model(**inputs)
print("attention_output shape: ", outputs['attention_output'].shape)
print("attention_weights shape: ", outputs['attention_weights'].shape)
print("last_hidden_state shape: ", outputs['last_hidden_state'].shape)
print("patch_reprs shape: ", outputs['patch_reprs'].shape)

# Only the parameters left trainable after freezing are handed to the optimizer.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

model.train()
for epoch in range(5):
    for idx, data in enumerate(train_dataloader):
        pixel_values = data['pixel_values'].to(device)
        label = data['labels'].to(device)
        outputs = model(pixel_values, label)
        loss = outputs['loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The two f-string pieces are concatenated, so the first needs its own
        # trailing separator; without it the log read "pred_acc: 0.76cossim_loss: ...".
        print(f"loss: {loss.item()}, acc: {outputs['acc'].item()}, pred_acc: {outputs['pred_acc'].item()}, "
              f"cossim_loss: {outputs['cossim_loss'].item()}, cls_loss: {outputs['cls_loss'].item()}")
        
        

attention_output shape:  torch.Size([1, 1, 768])
attention_weights shape:  torch.Size([1, 1, 196])
last_hidden_state shape:  torch.Size([1, 768])
patch_reprs shape:  torch.Size([1, 196, 768])
loss: 12.071322441101074, acc: 0.26171875, pred_acc: 0.765625cossim_loss: -0.24208229780197144, cls_loss: 12.31340503692627
loss: 124.10237121582031, acc: 0.04296875, pred_acc: 0.78515625cossim_loss: -0.17761749029159546, cls_loss: 124.27999114990234
loss: 242.25257873535156, acc: 0.0078125, pred_acc: 0.78515625cossim_loss: -0.1607733964920044, cls_loss: 242.41334533691406
loss: 104.09986114501953, acc: 0.03515625, pred_acc: 0.75cossim_loss: -0.22822889685630798, cls_loss: 104.32808685302734
loss: 55.381561279296875, acc: 0.06640625, pred_acc: 0.765625cossim_loss: -0.20862434804439545, cls_loss: 55.590187072753906
loss: 75.34963989257812, acc: 0.07421875, pred_acc: 0.75cossim_loss: -0.2037229835987091, cls_loss: 75.55335998535156
loss: 26.76494598388672, acc: 0.09375, pred_acc: 0.81640625cossim_lo

In [20]:
# for name, param in model.named_parameters():
#     if param.requires_grad == True:
#         print(name, param.requires_grad)
model_name = 'vit_sur_loss=cls+cos_att=cossim2.pt'
# torch.save does not create missing directories; make sure the target exists
# so the cell also works on a fresh checkout.
os.makedirs('model', exist_ok=True)
torch.save(model.state_dict(), f'model/{model_name}')

In [13]:
# Inspect the tensors from the most recent forward pass. These print the
# values themselves, so the labels no longer claim to show shapes.
print("attention_output: ", outputs['attention_output'])
print("attention_weights:", outputs['attention_weights'])
print("last_hidden_state: ", outputs['last_hidden_state'])
# Class predicted by the frozen head from the CLS hidden state.
logit = model.classifier(outputs['last_hidden_state'])
pred = logit.argmax(-1)
print('pred: ', pred)
# Class predicted by the same head from the surrogate's attended output.
int_logit = model.classifier(outputs['attention_output'])
interp = int_logit.argmax(-1)
print('labels: ', data['labels'])
print('interp: ', interp.reshape([-1]))

attention_output shape:  tensor([[[ -2.4139,   1.2525,  -5.8068,  ...,  -1.3713,   3.6165,  -0.0333]],

        [[  0.8080,  -2.6431,  -1.7622,  ...,   7.8847,  -0.8463,  10.1360]],

        [[  3.2632,  -9.8401,   4.9660,  ..., -11.6845,  -3.7304,  -4.1211]],

        ...,

        [[  0.3884,  -4.7305,  -2.3395,  ...,  -4.6104,  -0.1576,   2.3779]],

        [[  6.6678,  -0.7116,   0.7123,  ...,  -1.5447,  -0.6401,  -5.3296]],

        [[ -0.0722,  -1.1478,   5.5667,  ...,   2.7927,  -1.6272,  -2.9185]]],
       device='mps:0', grad_fn=<UnsafeViewBackward0>)
attention_weights: tensor([[[ 0.0062, -0.0446, -0.0248,  ..., -0.0851, -0.0127, -0.0037]],

        [[-0.0486,  0.0095, -0.0574,  ..., -0.0123, -0.0571, -0.0509]],

        [[ 0.0132, -0.0574,  0.0335,  ..., -0.0291, -0.0285,  0.0066]],

        ...,

        [[ 0.0173,  0.0006, -0.0141,  ..., -0.0051, -0.0056, -0.1184]],

        [[-0.0358,  0.0102, -0.0101,  ...,  0.0503, -0.0032,  0.0283]],

        [[-0.0095,  0.0219,  0.0039

In [14]:
# Display the attention weights stored in `outputs` by the most recent forward pass.
outputs['attention_weights']

tensor([[[ 0.0062, -0.0446, -0.0248,  ..., -0.0851, -0.0127, -0.0037]],

        [[-0.0486,  0.0095, -0.0574,  ..., -0.0123, -0.0571, -0.0509]],

        [[ 0.0132, -0.0574,  0.0335,  ..., -0.0291, -0.0285,  0.0066]],

        ...,

        [[ 0.0173,  0.0006, -0.0141,  ..., -0.0051, -0.0056, -0.1184]],

        [[-0.0358,  0.0102, -0.0101,  ...,  0.0503, -0.0032,  0.0283]],

        [[-0.0095,  0.0219,  0.0039,  ..., -0.0193, -0.0322,  0.0029]]],
       device='mps:0', grad_fn=<DivBackward0>)

In [15]:
from transformers import TrainingArguments, Trainer
# Metric used to select the best checkpoint; must match a key returned by compute_metrics.
metric_name = "accuracy"

args = TrainingArguments(
    # Plain string: the previous f-string had no placeholders. Passed by
    # keyword so the meaning of the first argument is explicit.
    output_dir="sur_model01",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    # Keep the raw 'image' column: the dataset transforms read it to build pixel_values.
    remove_unused_columns=False
)

In [16]:
from sklearn.metrics import accuracy_score
import numpy as np

def compute_metrics(eval_pred):
    """Compute top-1 accuracy from a Trainer evaluation result.

    Args:
        eval_pred: pair ``(predictions, labels)``. ``predictions`` is either a
            logits array of shape [N, C] or, when the model returns several
            outputs, a tuple/list of them.

    Returns:
        Dict with a single ``accuracy`` entry.
    """
    predictions, labels = eval_pred
    # A model that returns more than one output makes the Trainer hand over a
    # tuple; np.argmax on that tuple would fail, so use its first entry.
    # NOTE(review): this assumes the first entry is the logits array - confirm
    # against the model's output order.
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=1)
    # accuracy_score expects (y_true, y_pred); the value is the same either way.
    return dict(accuracy=accuracy_score(labels, predicted_classes))


In [17]:
import torch

# Hand the surrogate model, the training arguments, both dataset splits, the
# batch collator and the metric function to the Trainer. The image processor
# is passed in the tokenizer slot.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor,
    compute_metrics=compute_metrics,
)

In [18]:
# Run training with the Trainer configured above.
# NOTE(review): the recorded run stopped with an MPS out-of-memory error right
# after reaching epoch 1.0; presumably this is the end-of-epoch evaluation
# accumulating every model output - confirm before re-running.
trainer.train()


  4%|▎         | 500/13500 [01:45<43:35,  4.97it/s]

{'loss': -0.0818, 'grad_norm': 0.0, 'learning_rate': 1.925925925925926e-05, 'epoch': 0.11}


  7%|▋         | 1000/13500 [03:28<45:42,  4.56it/s]

{'loss': -0.0758, 'grad_norm': 0.0, 'learning_rate': 1.851851851851852e-05, 'epoch': 0.22}


 11%|█         | 1500/13500 [05:12<42:13,  4.74it/s]

{'loss': -0.0668, 'grad_norm': 0.0, 'learning_rate': 1.7777777777777777e-05, 'epoch': 0.33}


 15%|█▍        | 2000/13500 [06:57<40:17,  4.76it/s]

{'loss': -0.0879, 'grad_norm': 0.0, 'learning_rate': 1.7037037037037038e-05, 'epoch': 0.44}


 19%|█▊        | 2500/13500 [08:41<37:57,  4.83it/s]

{'loss': -0.0409, 'grad_norm': 0.0, 'learning_rate': 1.6296296296296297e-05, 'epoch': 0.56}


 22%|██▏       | 3000/13500 [10:25<36:15,  4.83it/s]

{'loss': -0.0952, 'grad_norm': 0.0, 'learning_rate': 1.555555555555556e-05, 'epoch': 0.67}


 26%|██▌       | 3501/13500 [12:10<33:33,  4.97it/s]

{'loss': -0.1246, 'grad_norm': 0.0, 'learning_rate': 1.4814814814814815e-05, 'epoch': 0.78}


 30%|██▉       | 4000/13500 [13:54<32:42,  4.84it/s]

{'loss': -0.038, 'grad_norm': 0.0, 'learning_rate': 1.4074074074074075e-05, 'epoch': 0.89}


 33%|███▎      | 4500/13500 [15:38<29:02,  5.17it/s]

{'loss': -0.0617, 'grad_norm': 0.0, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.0}




RuntimeError: MPS backend out of memory (MPS allocated: 35.93 GB, other allocations: 2.95 MB, max allowed: 36.27 GB). Tried to allocate 882.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# `dataset` was already reduced to a single split earlier, so rows are sliced
# directly (indexing it with 'train' would look for a column of that name).
# Images are converted to RGB, as the training transforms also do.
images = [image.convert('RGB') for image in dataset[:512]['image']]
inputs = processor(images=images, return_tensors="pt")
inputs = inputs.to(device)
# Use the wrapped classifier: the surrogate's forward() does not accept
# `output_hidden_states` and it has no `.config` with id2label.
with torch.no_grad():
    outputs = pred_model(**inputs, output_hidden_states=True)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes for every image; .item()
# only works for a single element, so collect the whole batch as a list.
predicted_class_idx = logits.argmax(-1).tolist()
print("Predicted class:", [pred_model.config.id2label[class_idx] for class_idx in predicted_class_idx])

In [None]:
# `dataset` is already the single split selected earlier, so slice its rows
# directly instead of indexing it with 'train'.
dataset[:512]['image']