In [1]:
import torch 
from transformers import ViTImageProcessor, ViTForImageClassification, ViTModel, ViTConfig
from PIL import Image
import requests
import matplotlib.pyplot as plt

import numpy as np
import cv2
from datasets import load_dataset,load_metric
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

from accelerate import Accelerator

# Let `accelerate` choose the compute device; the models and every batch used
# below are moved onto `device`.
accelerator = Accelerator()
device = accelerator.device


  from .autonotebook import tqdm as notebook_tqdm
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [2]:
# Sample image used to sanity-check the pretrained classifier.
# Other images tried:
#   http://images.cocodataset.org/val2017/000000039769.jpg
#   http://farm3.staticflickr.com/2066/1798910782_5536af8767_z.jpg
#   http://farm1.staticflickr.com/184/399924547_98e6cef97a_z.jpg
url = "http://farm1.staticflickr.com/128/318959350_1a39aae18c_z.jpg"
response = requests.get(url, stream=True, timeout=30)
# Fail loudly on HTTP errors instead of handing an error page to PIL.
response.raise_for_status()
# The processor normalizes with 3-channel mean/std, so force RGB
# (grayscale / RGBA images would otherwise break normalization).
image = Image.open(response.raw).convert("RGB")

pretrained_name = 'google/vit-base-patch16-224'
# pretrained_name = 'vit-base-patch16-224-finetuned-imageneteval'
# pretrained_name = 'openai/clip-vit-base-patch32'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
# get mean and std to unnormalize the processed images
mean, std = processor.image_mean, processor.image_std

pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)
# eval mode for inference (affects train-only layers such as dropout)
pred_model.eval()

with torch.no_grad():
    # BatchFeature.to returns the moved batch; use the returned value explicitly.
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = pred_model(**inputs, output_hidden_states=True)
    logits = outputs.logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", pred_model.config.id2label[predicted_class_idx])

Predicted class: golden retriever


In [3]:
from maskgen.models.vision_maskgen_model7 import MaskGeneratingModel

# Wrap the pretrained classifier in the mask-generating explainer. Hidden size and
# class count come from the ViT config so the explainer matches the backbone.
# Assumes MaskGeneratingModel is an nn.Module (it exposes named_parameters /
# state_dict below), whose .to() returns the module itself; assigning the result
# also keeps the cell from echoing the module repr, so no dummy print() is needed.
mask_gen_model = MaskGeneratingModel(
    pred_model,
    hidden_size=config.hidden_size,
    num_classes=config.num_labels,
).to(device)




# Training the mask generator

In [4]:
from torch.utils.data import DataLoader
def load_data(seed=42):
    """Split the ImageNet-1k validation dataset into train / val / test partitions.

    Uses the dataset's "train" split. 10% of it is held out as the test set, then
    10% of the remaining rows becomes the validation set (roughly 81 / 9 / 10).

    Args:
        seed: forwarded to both ``train_test_split`` calls so the splits are reproducible.

    Returns:
        Tuple ``(train_ds, val_ds, test_ds)``.
    """
    full_split = load_dataset("mrm8488/ImageNet1K-val")['train']
    outer = full_split.train_test_split(test_size=0.1, seed=seed)
    inner = outer['train'].train_test_split(test_size=0.1, seed=seed)
    return inner['train'], inner['test'], outer['test']

train_ds, val_ds, test_ds = load_data()
# Full, unsplit dataset (same source split that load_data() partitions).
# NOTE(review): it therefore contains every val_ds / test_ds image; anything trained
# on it has seen the held-out splits.
dataset = load_dataset("mrm8488/ImageNet1K-val")['train']

normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
# Derive the model input geometry from the processor config, which may specify
# either an exact height/width or a shortest edge.
if "height" in processor.size:
    size = (processor.size["height"], processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in processor.size:
    size = processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = processor.size.get("longest_edge")
else:
    # Without this, `size` / `crop_size` stay undefined and the transforms below
    # fail later with a confusing NameError.
    raise ValueError(
        f"Unsupported processor.size format: {processor.size!r}; "
        "expected 'height'/'width' or 'shortest_edge' keys."
    )

# Training-time augmentation: random resized crop and horizontal flip, followed by
# tensor conversion and normalization with the processor's mean/std.
transforms = Compose([
    RandomResizedCrop(crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

def preprocess(example_batch, transform=None):
    """Apply the (randomly augmenting) training transforms across a batch.

    Intended for ``Dataset.set_transform``. Each entry of ``example_batch["image"]``
    is converted to RGB and passed through ``transform``.

    Args:
        example_batch: dict-like batch with an "image" column; entries must support
            ``.convert("RGB")`` (presumably PIL images).
        transform: optional callable to use instead of the module-level
            ``transforms`` pipeline (the default, which applies random crop/flip).

    Returns:
        The same batch object with a "pixel_values" list added.
    """
    if transform is None:
        transform = transforms
    example_batch["pixel_values"] = [transform(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

def collate_fn(examples):
    """Collate preprocessed examples into a model-ready batch for the DataLoader.

    Args:
        examples: list of dicts, each holding a "pixel_values" tensor and an
            integer "label".

    Returns:
        dict with "pixel_values" (the per-example tensors stacked along a new
        leading dim) and "labels" (a tensor of the labels).
    """
    images, targets = [], []
    for example in examples:
        images.append(example["pixel_values"])
        targets.append(example["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

# Deterministic evaluation pipeline (no random crop/flip), so results on the
# held-out splits are reproducible across runs.
val_transforms = Compose([
    Resize(size),
    CenterCrop(crop_size),
    ToTensor(),
    normalize,
])


def preprocess_val(example_batch):
    """Apply the deterministic ``val_transforms`` across a batch (for val/test splits).

    Each entry of ``example_batch["image"]`` is converted to RGB, resized,
    center-cropped, converted to a tensor and normalized; the result is stored
    under "pixel_values" and the same batch object is returned.
    """
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch


# Training data gets random augmentation; held-out splits get the deterministic
# pipeline (previously test_ds was randomly augmented and val_ds had no transform).
train_ds.set_transform(preprocess)
val_ds.set_transform(preprocess_val)
test_ds.set_transform(preprocess_val)
dataset.set_transform(preprocess)

Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.


# Training parameters

In [5]:
SEED = 42  # same default seed load_data() uses for its splits
batch_size = 256
num_workers = 8
# train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers)
# train_dataloader = DataLoader(test_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers)
# NOTE(review): `dataset` is the full, unsplit set, so training here also covers the
# val_ds / test_ds images; switch to train_ds if those splits are used for evaluation.
# A seeded generator makes the shuffle order and the DataLoader workers' base RNG
# seed (which drives the random crops/flips) reproducible across runs.
train_dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    num_workers=num_workers,
    generator=torch.Generator().manual_seed(SEED),
)

# Passed as n_steps / n_samples to mask_gen_model.train_one_batch in the training loop.
n_steps = 2
n_samples = 5

# Names of the parameters that will be optimized (requires_grad=True).
params_to_optimize = [name for name, param in mask_gen_model.named_parameters() if param.requires_grad]
print("params to be optimized: ")
print(params_to_optimize)

params to be optimized: 
['similarity_measure.logit_scale', 'similarity_measure.pred_map.input_layer.weight', 'similarity_measure.pred_map.input_layer.bias', 'similarity_measure.pred_map.layers.0.0.weight', 'similarity_measure.pred_map.layers.0.0.bias', 'similarity_measure.pred_map.layers.0.3.weight', 'similarity_measure.pred_map.layers.0.3.bias', 'similarity_measure.pred_map.layers.0.6.weight', 'similarity_measure.pred_map.layers.0.6.bias', 'similarity_measure.pred_map.layers.1.0.weight', 'similarity_measure.pred_map.layers.1.0.bias', 'similarity_measure.pred_map.layers.1.3.weight', 'similarity_measure.pred_map.layers.1.3.bias', 'similarity_measure.pred_map.layers.1.6.weight', 'similarity_measure.pred_map.layers.1.6.bias', 'similarity_measure.pred_map.output_layer.weight', 'similarity_measure.pred_map.output_layer.bias', 'similarity_measure.explain_map.input_layer.weight', 'similarity_measure.explain_map.input_layer.bias', 'similarity_measure.explain_map.layers.0.0.weight', 'similarit

In [6]:
import os

from tqdm import tqdm

# Optimize only the parameters left trainable (requires_grad=True).
params_to_optimize = [param for param in mask_gen_model.parameters() if param.requires_grad]
# optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)

NUM_EPOCHS = 10
# Every CHECKPOINT_EVERY batches (starting at batch 0): print a newline, which leaves
# the current tqdm status line in the log, and save a checkpoint.
CHECKPOINT_EVERY = 10
CHECKPOINT_DIR = 'mask_gen_model'
# torch.save does not create parent directories; create the target up front instead
# of crashing right after the first batch.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print()

for epoch in range(NUM_EPOCHS):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        pixel_values = data['pixel_values'].to(device)
        # train_one_batch receives the optimizer, so presumably performs the update
        # step(s) itself; it returns a dict of scalar loss tensors that are logged below.
        loss_dict = mask_gen_model.train_one_batch(pixel_values, optimizer=optimizer, n_steps=n_steps, n_samples=n_samples)
        # scheduler.step()
        pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss_dict['loss'].item():.4f}, " 
                             f"Reward Loss = {loss_dict['reward_loss'].item():.4f}, "
                            #  f"Regret Loss = {loss_dict['regret_loss'].item():.4f}, "
                             f"Mask Loss = {loss_dict['mask_loss'].item():.4f} "
                            #  f"alt_mask_loss = {loss_dict['alt_mask_loss'].item():.4f} "
                             f"mask_mean = {loss_dict['mask_mean'].item():.4f} "
                             f"prob_mean = {loss_dict['prob_mean'].item():.4f} "
                             )
        if idx % CHECKPOINT_EVERY == 0:
            print()
            torch.save(mask_gen_model.state_dict(), os.path.join(CHECKPOINT_DIR, f'mask_gen_model_{epoch}_{idx}.pth'))

# Final checkpoint, tagged with the last epoch / batch index reached.
torch.save(mask_gen_model.state_dict(), os.path.join(CHECKPOINT_DIR, f'mask_gen_model_final_{epoch}_{idx}.pth'))





Epoch 1, Step 1: Loss = 0.2051, Reward Loss = -0.0426, Mask Loss = 0.4954 mask_mean = 0.4954 prob_mean = 0.6115 :   0%|          | 0/196 [00:23<?, ?it/s]




Epoch 1, Step 11: Loss = 0.0336, Reward Loss = -0.1695, Mask Loss = 0.4062 mask_mean = 0.4062 prob_mean = 0.4979 :   5%|▌         | 10/196 [04:14<1:09:47, 22.51s/it]




Epoch 1, Step 21: Loss = -0.0807, Reward Loss = -0.3014, Mask Loss = 0.4413 mask_mean = 0.4413 prob_mean = 0.5502 :  10%|█         | 20/196 [08:03<1:05:48, 22.44s/it]




Epoch 1, Step 31: Loss = -0.1962, Reward Loss = -0.4218, Mask Loss = 0.4511 mask_mean = 0.4511 prob_mean = 0.6072 :  15%|█▌        | 30/196 [11:52<1:02:04, 22.43s/it]




Epoch 1, Step 41: Loss = -0.0537, Reward Loss = -0.3001, Mask Loss = 0.4927 mask_mean = 0.4927 prob_mean = 0.5966 :  20%|██        | 40/196 [15:40<58:08, 22.36s/it]  




Epoch 1, Step 51: Loss = -0.3813, Reward Loss = -0.6222, Mask Loss = 0.4818 mask_mean = 0.4818 prob_mean = 0.6390 :  26%|██▌       | 50/196 [19:26<54:21, 22.34s/it]  




Epoch 1, Step 61: Loss = -0.1838, Reward Loss = -0.4177, Mask Loss = 0.4678 mask_mean = 0.4678 prob_mean = 0.5431 :  31%|███       | 60/196 [23:14<50:36, 22.33s/it]




Epoch 1, Step 71: Loss = -0.2946, Reward Loss = -0.5313, Mask Loss = 0.4734 mask_mean = 0.4734 prob_mean = 0.5475 :  36%|███▌      | 70/196 [27:01<46:55, 22.35s/it]




Epoch 1, Step 81: Loss = -0.2724, Reward Loss = -0.5101, Mask Loss = 0.4754 mask_mean = 0.4754 prob_mean = 0.5703 :  41%|████      | 80/196 [30:48<43:11, 22.34s/it]




Epoch 1, Step 91: Loss = -0.2761, Reward Loss = -0.5117, Mask Loss = 0.4712 mask_mean = 0.4712 prob_mean = 0.5305 :  46%|████▌     | 90/196 [34:35<39:27, 22.34s/it]




Epoch 1, Step 101: Loss = -0.3075, Reward Loss = -0.5340, Mask Loss = 0.4530 mask_mean = 0.4530 prob_mean = 0.5636 :  51%|█████     | 100/196 [38:22<35:47, 22.37s/it]




Epoch 1, Step 111: Loss = -0.3106, Reward Loss = -0.5369, Mask Loss = 0.4524 mask_mean = 0.4524 prob_mean = 0.5330 :  56%|█████▌    | 110/196 [42:09<32:05, 22.39s/it]




Epoch 1, Step 121: Loss = -0.3354, Reward Loss = -0.5581, Mask Loss = 0.4454 mask_mean = 0.4454 prob_mean = 0.5624 :  61%|██████    | 120/196 [45:57<28:23, 22.42s/it]




Epoch 1, Step 131: Loss = -0.2877, Reward Loss = -0.5025, Mask Loss = 0.4295 mask_mean = 0.4295 prob_mean = 0.5308 :  66%|██████▋   | 130/196 [49:45<24:39, 22.41s/it]




Epoch 1, Step 141: Loss = -0.3618, Reward Loss = -0.5714, Mask Loss = 0.4193 mask_mean = 0.4193 prob_mean = 0.5579 :  71%|███████▏  | 140/196 [53:34<20:56, 22.44s/it]




Epoch 1, Step 151: Loss = -0.3305, Reward Loss = -0.5403, Mask Loss = 0.4196 mask_mean = 0.4196 prob_mean = 0.5120 :  77%|███████▋  | 150/196 [57:22<17:10, 22.40s/it]




Epoch 1, Step 161: Loss = -0.3082, Reward Loss = -0.5201, Mask Loss = 0.4239 mask_mean = 0.4239 prob_mean = 0.5115 :  82%|████████▏ | 160/196 [1:01:09<13:25, 22.38s/it]




Epoch 1, Step 171: Loss = -0.3193, Reward Loss = -0.5231, Mask Loss = 0.4075 mask_mean = 0.4075 prob_mean = 0.5259 :  87%|████████▋ | 170/196 [1:04:57<09:42, 22.41s/it]




Epoch 1, Step 181: Loss = -0.3261, Reward Loss = -0.5245, Mask Loss = 0.3969 mask_mean = 0.3969 prob_mean = 0.5096 :  92%|█████████▏| 180/196 [1:08:45<05:58, 22.44s/it]




Epoch 1, Step 191: Loss = -0.3863, Reward Loss = -0.5796, Mask Loss = 0.3866 mask_mean = 0.3866 prob_mean = 0.5583 :  97%|█████████▋| 190/196 [1:12:33<02:14, 22.46s/it]




Epoch 1, Step 196: Loss = -0.3825, Reward Loss = -0.5705, Mask Loss = 0.3759 mask_mean = 0.3759 prob_mean = 0.5354 : 100%|██████████| 196/196 [1:14:15<00:00, 22.73s/it]
Epoch 2, Step 1: Loss = -0.3051, Reward Loss = -0.4955, Mask Loss = 0.3807 mask_mean = 0.3807 prob_mean = 0.4898 :   0%|          | 0/196 [00:23<?, ?it/s]




Epoch 2, Step 11: Loss = -0.2800, Reward Loss = -0.4716, Mask Loss = 0.3832 mask_mean = 0.3832 prob_mean = 0.4534 :   5%|▌         | 10/196 [04:11<1:09:31, 22.43s/it]




Epoch 2, Step 21: Loss = -0.3208, Reward Loss = -0.5090, Mask Loss = 0.3765 mask_mean = 0.3765 prob_mean = 0.5043 :  10%|█         | 20/196 [08:00<1:05:43, 22.40s/it]




Epoch 2, Step 31: Loss = -0.3140, Reward Loss = -0.5011, Mask Loss = 0.3741 mask_mean = 0.3741 prob_mean = 0.4947 :  15%|█▌        | 30/196 [11:50<1:02:28, 22.58s/it]




Epoch 2, Step 41: Loss = -0.2284, Reward Loss = -0.4167, Mask Loss = 0.3766 mask_mean = 0.3766 prob_mean = 0.4654 :  20%|██        | 40/196 [15:38<58:25, 22.47s/it]  




Epoch 2, Step 51: Loss = -0.3018, Reward Loss = -0.4910, Mask Loss = 0.3784 mask_mean = 0.3784 prob_mean = 0.4798 :  26%|██▌       | 50/196 [19:27<54:39, 22.46s/it]  




Epoch 2, Step 56: Loss = -0.2965, Reward Loss = -0.4844, Mask Loss = 0.3757 mask_mean = 0.3757 prob_mean = 0.4745 :  29%|██▊       | 56/196 [22:43<1:24:02, 36.02s/it]

In [None]:
3262

3262