In [1]:
import torch 
from transformers import ViTImageProcessor, ViTForImageClassification, ViTModel, ViTConfig
from PIL import Image
import requests
import matplotlib.pyplot as plt

import numpy as np
import cv2
from datasets import load_dataset,load_metric
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

from accelerate import Accelerator

from accelerate.utils import set_seed

# Seed Python, NumPy and torch RNGs so data augmentation (RandomResizedCrop,
# RandomHorizontalFlip) and DataLoader shuffling repeat across runs.
SEED = 42
set_seed(SEED)

# Let accelerate pick the device; `device` is used by every later cell.
accelerator = Accelerator()
device = accelerator.device


  from .autonotebook import tqdm as notebook_tqdm
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [2]:
# Sanity check: classify one example image with the pretrained ViT.
# Alternative test images:
#   http://images.cocodataset.org/val2017/000000039769.jpg
#   http://farm3.staticflickr.com/2066/1798910782_5536af8767_z.jpg
#   http://farm1.staticflickr.com/184/399924547_98e6cef97a_z.jpg
url = "http://farm1.staticflickr.com/128/318959350_1a39aae18c_z.jpg"
response = requests.get(url, stream=True, timeout=30)
# Fail loudly on HTTP errors instead of handing an error page to PIL.
response.raise_for_status()
# convert("RGB") forces 3 channels (grayscale/RGBA/palette images would otherwise
# break the processor) and fully loads the pixels while the stream is still open.
image = Image.open(response.raw).convert("RGB")

# Alternative checkpoints:
#   'vit-base-patch16-224-finetuned-imageneteval'
#   'openai/clip-vit-base-patch32'
pretrained_name = 'google/vit-base-patch16-224'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
# get mean and std to unnormalize the processed images
mean, std = processor.image_mean, processor.image_std

pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)
# set to eval mode
pred_model.eval()

with torch.no_grad():
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = pred_model(**inputs, output_hidden_states=True)
    logits = outputs.logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", pred_model.config.id2label[predicted_class_idx])

Predicted class: golden retriever


In [3]:
from maskgen.models.mask_generating_model12 import MaskGeneratingModel

# Wrap the pretrained classifier in the mask-generating model; hidden size and
# class count come from the ViT config loaded in the previous cell.
mask_gen_model = MaskGeneratingModel(pred_model, hidden_size=config.hidden_size, num_classes=config.num_labels)
# Trailing ";" hides the module repr without printing a stray blank line.
mask_gen_model.to(device);




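# Training

Load the `mrm8488/ImageNet1K-val` dataset, split it into train/val/test subsets, and set up the training-time augmentation.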

In [4]:
from torch.utils.data import DataLoader
def load_data(seed=42):
    """Split the ``train`` split of ``mrm8488/ImageNet1K-val`` into three subsets.

    First 10% of the rows are held out as the test set. Then 10% of the
    remaining rows become the validation set. The same ``seed`` drives both
    splits, so the partition is reproducible.

    Returns:
        Tuple ``(train_ds, val_ds, test_ds)``.
    """
    full_ds = load_dataset("mrm8488/ImageNet1K-val")["train"]
    holdout = full_ds.train_test_split(test_size=0.1, seed=seed)
    train_val = holdout["train"].train_test_split(test_size=0.1, seed=seed)
    return train_val["train"], train_val["test"], holdout["test"]

train_ds, val_ds, test_ds = load_data()

# Normalize with the same statistics the pretrained processor uses.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)

# processor.size is either an explicit {"height", "width"} pair or a
# {"shortest_edge"[, "longest_edge"]} spec; derive the crop size from either.
if "height" in processor.size:
    size = (processor.size["height"], processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in processor.size:
    size = processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = processor.size.get("longest_edge")
else:
    # Without this branch `crop_size` would be undefined and the failure would
    # only show up later as a confusing NameError.
    raise ValueError(f"Unsupported processor.size format: {processor.size!r}")
# NOTE(review): `max_size` is not used by any visible cell.

# Training-time augmentation: random resized crop + horizontal flip, then
# tensor conversion and normalization.
transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

def preprocess(example_batch, transform=None):
    """Apply the training-time ``transforms`` to every image in a batch.

    Intended for ``Dataset.set_transform``. Each image is converted to RGB
    first, so grayscale/RGBA inputs get three channels.

    Args:
        example_batch: dict-like batch with an ``"image"`` column (presumably
            PIL images; anything with ``.convert("RGB")`` works).
        transform: optional callable used instead of the module-level
            ``transforms``, e.g. a deterministic pipeline for val/test data.

    Returns:
        The same batch object, with a ``"pixel_values"`` list of transformed images.
    """
    # Look up the module-level pipeline at call time, as the original did.
    tf = transforms if transform is None else transform
    example_batch["pixel_values"] = [tf(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

def collate_fn(examples):
    """Merge a list of dataset rows into one batch for the DataLoader.

    Every row's ``"pixel_values"`` tensor is stacked along a new leading batch
    dimension. The ``"label"`` values become a 1-D tensor exposed as ``"labels"``.
    """
    batch_pixels = torch.stack([row["pixel_values"] for row in examples])
    batch_labels = torch.tensor([row["label"] for row in examples])
    return dict(pixel_values=batch_pixels, labels=batch_labels)

# Register `preprocess` as the on-the-fly transform for the training split
# (datasets `Dataset.set_transform`); val/test splits are left untransformed here.
train_ds.set_transform(transform=preprocess)


Repo card metadata block was not found. Setting CardData to empty.


# Training parameters

Batch size, DataLoader workers, and the `n_steps` value passed to `train_one_batch`.

In [5]:
batch_size = 256
num_workers = 8
n_steps = 10  # forwarded to mask_gen_model.train_one_batch(..., n_steps=n_steps)

train_dataloader = DataLoader(
    train_ds,
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    num_workers=num_workers,
)

# Parameter *names* for inspection only. Kept under a distinct name so the
# optimizer cell's `params_to_optimize` (Parameter objects) is not confused with it.
trainable_param_names = [name for name, param in mask_gen_model.named_parameters() if param.requires_grad]
print("params to be optimized: ")
print(trainable_param_names)

params to be optimized: 
['similarity_measure.logit_scale', 'similarity_measure.pred_map.input_layer.weight', 'similarity_measure.pred_map.input_layer.bias', 'similarity_measure.pred_map.layers.0.0.weight', 'similarity_measure.pred_map.layers.0.0.bias', 'similarity_measure.pred_map.layers.0.3.weight', 'similarity_measure.pred_map.layers.0.3.bias', 'similarity_measure.pred_map.layers.0.6.weight', 'similarity_measure.pred_map.layers.0.6.bias', 'similarity_measure.pred_map.layers.1.0.weight', 'similarity_measure.pred_map.layers.1.0.bias', 'similarity_measure.pred_map.layers.1.3.weight', 'similarity_measure.pred_map.layers.1.3.bias', 'similarity_measure.pred_map.layers.1.6.weight', 'similarity_measure.pred_map.layers.1.6.bias', 'similarity_measure.pred_map.output_layer.weight', 'similarity_measure.pred_map.output_layer.bias', 'similarity_measure.explain_map.input_layer.weight', 'similarity_measure.explain_map.input_layer.bias', 'similarity_measure.explain_map.layers.0.0.weight', 'similarit

In [6]:
import os

from tqdm import tqdm

n_epochs = 10
log_every = 10     # print a newline every N steps so a progress-bar snapshot stays in the output
save_every = 30    # checkpoint every N steps (includes step 0 of each epoch)
checkpoint_dir = "mask_gen_model"

params_to_optimize = [param for param in mask_gen_model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)

# torch.save does not create parent directories.
os.makedirs(checkpoint_dir, exist_ok=True)

# (loss_dict key, display label) pairs, in the order shown in the progress bar.
_LOG_FIELDS = [
    ("loss", "Loss"),
    ("reward_loss", "Reward Loss"),
    ("regret_loss", "Regret Loss"),
    ("mask_loss", "Mask Loss"),
    ("alt_mask_loss", "alt_mask_loss"),
    ("mask_mean", "mask_mean"),
    ("prob_mean", "prob_mean"),
]


def _format_losses(loss_dict):
    """Render the scalar tensors in ``loss_dict`` as comma-separated 'Label = value' pairs."""
    return ", ".join(f"{label} = {loss_dict[key].item():.4f}" for key, label in _LOG_FIELDS)


for epoch in range(n_epochs):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        pixel_values = data['pixel_values'].to(device)
        loss_dict = mask_gen_model.train_one_batch(pixel_values, optimizer=optimizer, n_steps=n_steps)
        # scheduler.step()
        pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: {_format_losses(loss_dict)}")
        if idx % log_every == 0:
            print()
        if idx % save_every == 0:
            torch.save(
                mask_gen_model.state_dict(),
                os.path.join(checkpoint_dir, f'mask_gen_model_{epoch}_{idx}.pth'),
            )

torch.save(
    mask_gen_model.state_dict(),
    os.path.join(checkpoint_dir, f'mask_gen_model_final_{epoch}_{idx}.pth'),
)





Epoch 1, Step 1: Loss = 0.4140, Reward Loss = 0.4030, Regret Loss = 0.1692, Mask Loss = 0.5087 alt_mask_loss = 1.1055mask_mean = 0.5072 prob_mean = 0.5827 :   0%|          | 0/159 [00:16<?, ?it/s]




Epoch 1, Step 11: Loss = 0.1257, Reward Loss = 0.1238, Regret Loss = 0.8926, Mask Loss = 0.1967 alt_mask_loss = 0.1860mask_mean = 0.1910 prob_mean = 0.2486 :   7%|▋         | 11/159 [02:51<37:28, 15.19s/it]




Epoch 1, Step 21: Loss = 0.1666, Reward Loss = 0.1637, Regret Loss = 0.7182, Mask Loss = 0.2436 alt_mask_loss = 0.2834mask_mean = 0.2319 prob_mean = 0.2978 :  13%|█▎        | 21/159 [05:21<34:32, 15.01s/it]




Epoch 1, Step 31: Loss = 0.2207, Reward Loss = 0.2165, Regret Loss = 0.6554, Mask Loss = 0.2860 alt_mask_loss = 0.4217mask_mean = 0.2576 prob_mean = 0.3772 :  19%|█▉        | 30/159 [07:51<32:15, 15.00s/it]




Epoch 1, Step 41: Loss = 0.1214, Reward Loss = 0.1194, Regret Loss = 0.9819, Mask Loss = 0.2219 alt_mask_loss = 0.2056mask_mean = 0.1826 prob_mean = 0.2432 :  26%|██▌       | 41/159 [10:53<36:15, 18.44s/it]




Epoch 1, Step 51: Loss = 0.1972, Reward Loss = 0.1934, Regret Loss = 0.7126, Mask Loss = 0.3066 alt_mask_loss = 0.3843mask_mean = 0.2377 prob_mean = 0.3591 :  32%|███▏      | 51/159 [13:23<27:12, 15.12s/it]




Epoch 1, Step 61: Loss = 0.1737, Reward Loss = 0.1703, Regret Loss = 0.7753, Mask Loss = 0.2869 alt_mask_loss = 0.3375mask_mean = 0.2195 prob_mean = 0.3282 :  38%|███▊      | 60/159 [15:53<24:47, 15.03s/it]




Epoch 1, Step 71: Loss = 0.1025, Reward Loss = 0.1009, Regret Loss = 0.9814, Mask Loss = 0.2353 alt_mask_loss = 0.1653mask_mean = 0.1765 prob_mean = 0.2131 :  45%|████▍     | 71/159 [18:56<22:44, 15.50s/it]




Epoch 1, Step 81: Loss = 0.2228, Reward Loss = 0.2184, Regret Loss = 0.7646, Mask Loss = 0.3091 alt_mask_loss = 0.4400mask_mean = 0.2400 prob_mean = 0.4097 :  51%|█████     | 81/159 [21:26<19:31, 15.02s/it]




Epoch 1, Step 91: Loss = 0.1615, Reward Loss = 0.1586, Regret Loss = 0.8333, Mask Loss = 0.2764 alt_mask_loss = 0.2930mask_mean = 0.2117 prob_mean = 0.3126 :  57%|█████▋    | 90/159 [24:23<23:02, 20.04s/it]




Epoch 1, Step 101: Loss = 0.1987, Reward Loss = 0.1947, Regret Loss = 0.8592, Mask Loss = 0.3046 alt_mask_loss = 0.3957mask_mean = 0.2211 prob_mean = 0.3818 :  64%|██████▎   | 101/159 [26:58<14:42, 15.22s/it]




Epoch 1, Step 111: Loss = 0.1892, Reward Loss = 0.1855, Regret Loss = 0.7670, Mask Loss = 0.3096 alt_mask_loss = 0.3671mask_mean = 0.2307 prob_mean = 0.3551 :  70%|██████▉   | 111/159 [29:55<15:37, 19.54s/it]




Epoch 1, Step 121: Loss = 0.1331, Reward Loss = 0.1308, Regret Loss = 0.8926, Mask Loss = 0.2541 alt_mask_loss = 0.2290mask_mean = 0.1885 prob_mean = 0.2702 :  75%|███████▌  | 120/159 [32:52<11:56, 18.38s/it]




Epoch 1, Step 131: Loss = 0.1791, Reward Loss = 0.1757, Regret Loss = 0.8903, Mask Loss = 0.2871 alt_mask_loss = 0.3425mask_mean = 0.2129 prob_mean = 0.3482 :  82%|████████▏ | 131/159 [35:26<07:02, 15.10s/it]




Epoch 1, Step 141: Loss = 0.1965, Reward Loss = 0.1925, Regret Loss = 0.7628, Mask Loss = 0.3267 alt_mask_loss = 0.4003mask_mean = 0.2316 prob_mean = 0.3712 :  89%|████████▊ | 141/159 [37:56<04:30, 15.00s/it]




Epoch 1, Step 151: Loss = 0.1660, Reward Loss = 0.1628, Regret Loss = 0.9118, Mask Loss = 0.2874 alt_mask_loss = 0.3131mask_mean = 0.2006 prob_mean = 0.3332 :  94%|█████████▍| 150/159 [41:01<02:35, 17.27s/it]




Epoch 1, Step 159: Loss = 0.1528, Reward Loss = 0.1498, Regret Loss = 0.8862, Mask Loss = 0.2899 alt_mask_loss = 0.3010mask_mean = 0.2032 prob_mean = 0.3029 : 100%|██████████| 159/159 [43:14<00:00, 16.32s/it]
Epoch 2, Step 1: Loss = 0.1720, Reward Loss = 0.1686, Regret Loss = 0.9114, Mask Loss = 0.2959 alt_mask_loss = 0.3338mask_mean = 0.2090 prob_mean = 0.3375 :   0%|          | 0/159 [00:17<?, ?it/s]




Epoch 2, Step 11: Loss = 0.1928, Reward Loss = 0.1888, Regret Loss = 0.8265, Mask Loss = 0.3369 alt_mask_loss = 0.4017mask_mean = 0.2268 prob_mean = 0.3725 :   7%|▋         | 11/159 [02:51<37:03, 15.02s/it]




Epoch 2, Step 21: Loss = 0.2474, Reward Loss = 0.2417, Regret Loss = 0.7374, Mask Loss = 0.3941 alt_mask_loss = 0.5681mask_mean = 0.2574 prob_mean = 0.4780 :  13%|█▎        | 21/159 [05:59<47:45, 20.76s/it]




Epoch 2, Step 31: Loss = 0.1648, Reward Loss = 0.1614, Regret Loss = 0.9199, Mask Loss = 0.3176 alt_mask_loss = 0.3357mask_mean = 0.2001 prob_mean = 0.3411 :  19%|█▉        | 30/159 [08:28<32:41, 15.20s/it]




Epoch 2, Step 41: Loss = 0.1194, Reward Loss = 0.1173, Regret Loss = 1.0752, Mask Loss = 0.2424 alt_mask_loss = 0.2096mask_mean = 0.1667 prob_mean = 0.2626 :  26%|██▌       | 41/159 [11:02<29:31, 15.01s/it]




Epoch 2, Step 50: Loss = 0.2399, Reward Loss = 0.2347, Regret Loss = 0.7584, Mask Loss = 0.3739 alt_mask_loss = 0.5208mask_mean = 0.2553 prob_mean = 0.4573 :  31%|███▏      | 50/159 [13:17<27:11, 14.96s/it]

In [None]:
# NOTE(review): stray scratch cell (marked In [None], never executed); it only evaluates the literal 3262 — consider deleting.
3262

3262