In [1]:
import torch
import clip
import numpy as np
import matplotlib.pyplot as plt
import os
import math
from torchvision.datasets import CIFAR10
from torchvision.transforms import *
from tqdm.notebook import tqdm

In [2]:
from torch.utils.data.dataset import Subset
from torchvision.transforms import InterpolationMode
# Shorthand for the bicubic interpolation mode; passed to Resize inside _transform.
BICUBIC = InterpolationMode.BICUBIC

def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    """Build the image preprocessing pipeline for square inputs of side `n_px`.

    The steps run in order: bicubic resize to `n_px`, centre crop to `n_px`,
    conversion to RGB, conversion to a tensor, and per-channel normalisation.

    Args:
        n_px: Target size handed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``Compose`` object chaining the steps above.
    """
    # Per-channel statistics used by the final Normalize step.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)

# Preprocessing shared by the validation and test datasets.
transforms = _transform(224)

# The .npy file is read as a per-position flag array: every position whose value is
# non-zero is kept as an index into the CIFAR-10 training split.
idxs = np.load('cifar1098_idxs.npy').astype('int')
indices = [position for position, flag in enumerate(idxs) if flag]

# Validation set = the selected subset of the training split; test set = the full test split.
val = Subset(
    CIFAR10(root='./data', train=True, download=True, transform=transforms),
    indices,
)
test = CIFAR10(root='./data', train=False, download=True, transform=transforms)

# Both loaders iterate in a fixed order and keep the final partial batch.
_loader_options = dict(batch_size=128, shuffle=False, num_workers=2, drop_last=False)
valloader = torch.utils.data.DataLoader(val, **_loader_options)
testloader = torch.utils.data.DataLoader(test, **_loader_options)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def validate(model):
  """Return the top-1 accuracy of `model` on `valloader`.

  Args:
    model: Callable mapping a batch of images to per-class logits.

  Returns:
    Fraction of validation examples whose arg-max prediction equals the label,
    as a Python float.
  """
  preds = []
  labels = []
  # Evaluation only: disable autograd so no graph/activations are kept per batch.
  with torch.no_grad():
    for x, y in tqdm(valloader):
      x = x.to(device)
      y = y.to(device)
      preds.append(model(x).argmax(dim=1))
      labels.append(y)
  return torch.mean((torch.cat(preds) == torch.cat(labels)).float()).item()

def test(model):
  """Return the top-1 accuracy of `model` on `testloader`.

  Note: defining this function rebinds the module-level name `test`, which the
  data-loading cell also uses for the CIFAR10 test dataset. `testloader` keeps its
  own reference to that dataset, so evaluation is unaffected.

  Args:
    model: Callable mapping a batch of images to per-class logits.

  Returns:
    Fraction of test examples whose arg-max prediction equals the label,
    as a Python float.
  """
  preds = []
  labels = []
  # Evaluation only: disable autograd so no graph/activations are kept per batch.
  with torch.no_grad():
    for x, y in tqdm(testloader):
      x = x.to(device)
      y = y.to(device)
      preds.append(model(x).argmax(dim=1))
      labels.append(y)
  return torch.mean((torch.cat(preds) == torch.cat(labels)).float()).item()

In [4]:
class ModelWrapper(torch.nn.Module):
    """Image classifier: a backbone's ``encode_image`` followed by a linear head.

    Args:
        model: Backbone exposing ``encode_image(images)``. If it has a
            ``transformer`` attribute, that attribute is deleted from the backbone
            in place (the original note calls this "the language part").
        feature_dim: Input width of the classification head.
        num_classes: Number of logits produced by the head.
        normalize: If True, features are divided by their L2 norm before the head.
        initial_weights: One of
            * ``None`` - Kaiming-uniform weight, zero bias;
            * a weight tensor - used as the head weight, zero bias;
            * a ``(weight, bias)`` tuple - both are cloned into the head.
    """

    def __init__(self, model, feature_dim, num_classes, normalize=False, initial_weights=None):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.classification_head = torch.nn.Linear(feature_dim, num_classes)
        self.normalize = normalize
        if not self.normalize:
            print('normalize skipped.')

        # isinstance (rather than an exact type check) also accepts tuple subclasses.
        if initial_weights is not None and isinstance(initial_weights, tuple):
            print('tuple.')
            w, b = initial_weights
            self.classification_head.weight = torch.nn.Parameter(w.clone())
            self.classification_head.bias = torch.nn.Parameter(b.clone())
        else:
            if initial_weights is None:
                initial_weights = torch.zeros_like(self.classification_head.weight)
                torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))
            self.classification_head.weight = torch.nn.Parameter(initial_weights.clone())
            # Note: modified. Initial bug in forgetting to zero bias.
            self.classification_head.bias = torch.nn.Parameter(torch.zeros_like(self.classification_head.bias))

        # Note: modified. Get rid of the language part.
        # Guarded so that a backbone without the attribute (e.g. one that has already
        # been wrapped once) does not raise AttributeError.
        if hasattr(self.model, 'transformer'):
            delattr(self.model, 'transformer')

    def forward(self, images):
        """Encode `images` with the backbone and return the head's logits.

        Features are cast to float32 and, when ``self.normalize`` is set, divided by
        their L2 norm along the last dimension before the linear head is applied.
        """
        features = self.model.encode_image(images).float()
        if self.normalize:
            features = features / features.norm(dim=-1, keepdim=True)
        logits = self.classification_head(features)
        return logits

In [5]:
# Prefer GPU index 2 when the host actually has it; otherwise fall back to the first
# GPU, or to the CPU when CUDA is unavailable ('cuda:2' is invalid with < 3 GPUs).
if torch.cuda.is_available():
  device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0'
else:
  device = 'cpu'
state_dicts = []

# Load the first checkpoint (alphabetical order) found in the working directory.
for f in sorted(os.listdir()):
  # Match the real '.pt' extension; a bare "ends with 'pt'" test would also pick up
  # unrelated files such as 'transcript'.
  if f.endswith('.pt'):
    print(f'Loading {f}')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dicts.append(torch.load(f, map_location=device))
    # Only the first matching checkpoint is loaded.
    break

Loading checkpoint_10.1.pt


In [6]:
def get_model(state_dicts, alphal):
  """Build a classifier whose weights are a linear combination of checkpoints.

  Every entry of the result is ``sum_i alphal[i] * state_dicts[i][key]``, loaded
  into a freshly created ``ModelWrapper`` around CLIP ViT-B/32 and moved to `device`.

  Args:
    state_dicts: Non-empty sequence of state dicts sharing the same keys. The head
      dimensions are read from ``state_dicts[0]['classification_head.weight']``.
    alphal: One mixing coefficient per state dict.

  Returns:
    The ``ModelWrapper`` with the combined weights, on `device`.

  Raises:
    ValueError: If `state_dicts` is empty or its length differs from `alphal`.
    KeyError: If a state dict's keys differ from those of the first one.
  """
  if len(state_dicts) == 0:
    raise ValueError('get_model needs at least one state dict.')
  if len(alphal) != len(state_dicts):
    raise ValueError(
        f'Got {len(alphal)} coefficients for {len(state_dicts)} state dicts.')

  # Linear weight is (out_features, in_features) == (num_classes, feature_dim).
  head_weight = state_dicts[0]['classification_head.weight']
  feature_dim = head_weight.shape[1]
  num_classes = head_weight.shape[0]
  normalize = True

  # Load straight onto the target device instead of the library's default device.
  model, _ = clip.load('ViT-B/32', device=device)
  model = ModelWrapper(model, feature_dim, num_classes, normalize)

  reference_keys = state_dicts[0].keys()
  # Multiplication already allocates a new tensor, so the inputs are never modified.
  sd = {k: state_dicts[0][k] * alphal[0] for k in reference_keys}
  for alpha, other in zip(alphal[1:], state_dicts[1:]):
    # A missing key would otherwise silently leave a partially summed entry.
    if other.keys() != reference_keys:
      raise KeyError('All state dicts must have identical keys.')
    for k in reference_keys:
      sd[k] = sd[k] + other[k] * alpha

  model.load_state_dict(sd)
  model = model.to(device)
  return model

In [7]:
def add_random_noise_to_model(model, noise_factor=0.01):
    """Perturb the trainable parameters of `model` in place with Gaussian noise.

    For every parameter with ``requires_grad`` set, a standard-normal tensor of the
    same shape is scaled by `noise_factor` and added to the parameter. Parameters
    that do not require gradients are left untouched.

    Args:
        model: Module whose ``parameters()`` are perturbed.
        noise_factor: Multiplier applied to the standard-normal samples.

    Returns:
        None; `model` is modified in place.
    """
    with torch.no_grad():
        for weight in model.parameters():
            # Frozen parameters are skipped entirely (no random numbers are drawn).
            if not weight.requires_grad:
                continue
            perturbation = torch.randn_like(weight) * noise_factor
            weight.add_(perturbation)

In [19]:
import copy

# Experiment: perturb the fine-tuned checkpoint with Gaussian noise several times and
# record validation/test accuracy. Index 0 of each list below is the unperturbed
# checkpoint; the following entries are the noisy copies.
state_dicts_2 = []
val_results = []
test_results = []

noise_factor = 2e-3  # You can adjust this value
num_noisy_models = 5

model, _ = clip.load('ViT-B/32')
feature_dim = state_dicts[0]['classification_head.weight'].shape[1]
num_classes = state_dicts[0]['classification_head.weight'].shape[0]
normalize = True
model = ModelWrapper(model, feature_dim, num_classes, normalize)
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dicts_2.append(torch.load('checkpoint_10.1.pt', map_location=device))
model.load_state_dict(state_dicts_2[0])

model = model.to(device)
# Pristine reference that is never perturbed; every noisy model is cloned from it.
real_model = copy.deepcopy(model)
val_results.append(validate(model))
test_results.append(test(model))
print(val_results[-1], test_results[-1])

# Seed the global RNG so the sampled noise (and the printed accuracies) are reproducible.
torch.manual_seed(0)
for j in range(num_noisy_models):
  # Perturb a fresh copy each time: `model` and `real_model` stay clean, and every
  # stored state dict (which references its model's parameter tensors) belongs to
  # a model that is never modified again.
  noisy_model = copy.deepcopy(real_model)
  add_random_noise_to_model(noisy_model, noise_factor)
  state_dicts_2.append(noisy_model.state_dict())

  val_results.append(validate(noisy_model))
  test_results.append(test(noisy_model))
  print(val_results[-1], test_results[-1])

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

0.9825999736785889 0.9769999980926514


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

0.9715999960899353 0.9645999670028687


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

0.9717999696731567 0.9670999646186829


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

0.9684000015258789 0.9627999663352966


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

0.9703999757766724 0.9644999504089355


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

0.9699999690055847 0.965999960899353


In [14]:
# Show the accumulated validation accuracies followed by the test accuracies.
print(f"{val_results} {test_results}")

[0.9825999736785889, 0.9815999865531921, 0.9824000000953674, 0.9819999933242798, 0.9817999601364136, 0.9811999797821045] [0.9769999980926514, 0.9765999913215637, 0.976699948310852, 0.9751999974250793, 0.976099967956543, 0.9767999649047852]


In [15]:
# Number of collected state dicts (the loaded checkpoint plus the noisy copies).
num_state_dicts = len(state_dicts_2)
print(num_state_dicts)

6


In [16]:
# Average all collected state dicts with equal weights and evaluate the result
# on the test set; the accuracy is appended to `test_results`.
num_models = len(state_dicts_2)
alphal = [1 / num_models for _ in range(num_models)]
model = get_model(state_dicts_2, alphal)
averaged_test_accuracy = test(model)
test_results.append(averaged_test_accuracy)
print(test_results[-1])

  0%|          | 0/79 [00:00<?, ?it/s]

0.9765999913215637
