In [1]:
import torch 
from transformers import ViTImageProcessor, ViTForImageClassification, ViTModel, ViTConfig
from PIL import Image
import requests
import matplotlib.pyplot as plt

import numpy as np
import cv2
from datasets import load_dataset,load_metric
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

from accelerate import Accelerator

# Create the Accelerate helper and take its device; `device` is the target of
# every `.to(device)` call in the cells below.
accelerator = Accelerator()
device = accelerator.device


  from .autonotebook import tqdm as notebook_tqdm
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [2]:
# Sample image to classify; the commented URLs are alternative test images.
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# url = "http://farm3.staticflickr.com/2066/1798910782_5536af8767_z.jpg"
# url = "http://farm1.staticflickr.com/184/399924547_98e6cef97a_z.jpg"
url = "http://farm1.staticflickr.com/128/318959350_1a39aae18c_z.jpg"

# Fail fast with a clear HTTP error (and never hang indefinitely) instead of
# handing an error page to PIL, which would surface as an opaque decode failure.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Force 3-channel RGB, consistent with what `preprocess` does for training images.
image = Image.open(response.raw).convert("RGB")

# Checkpoint shared by the config, the image processor and the classifier.
pretrained_name = 'google/vit-base-patch16-224'
# pretrained_name = 'vit-base-patch16-224-finetuned-imageneteval'
# pretrained_name = 'openai/clip-vit-base-patch32'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
# get mean and std to unnormalize the processed images
mean, std = processor.image_mean, processor.image_std

pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)
# set to eval mode
pred_model.eval()

# Sanity check: classify the sample image without tracking gradients.
with torch.no_grad():
    # Keep the value returned by `.to(device)` rather than relying on the
    # processor output being moved in place.
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = pred_model(**inputs, output_hidden_states=True)
    logits = outputs.logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", pred_model.config.id2label[predicted_class_idx])

Predicted class: golden retriever


In [3]:
from maskgen.models.mask_generating_model8 import MaskGeneratingModel

# Build the mask generator around the classifier loaded above, taking the
# hidden width and the number of classes from the ViT config.
mask_gen_model = MaskGeneratingModel(
    pred_model,
    hidden_size=config.hidden_size,
    num_classes=config.num_labels,
)
mask_gen_model.to(device)
# Ending the cell on a bare print() presumably keeps the notebook from echoing
# the value returned by `.to(device)`.
print()




# Training

In [4]:
from torch.utils.data import DataLoader
def load_data(seed=42):
    """Load `mrm8488/ImageNet1K-val` and carve it into train/val/test subsets.

    The dataset's "train" split is first divided 90/10 to hold out the test
    set; the remaining 90% is divided 90/10 again into train and validation.
    Both divisions use the same `seed`.

    Args:
        seed: Seed passed to both `train_test_split` calls.

    Returns:
        Tuple `(train_ds, val_ds, test_ds)`.
    """
    full_ds = load_dataset("mrm8488/ImageNet1K-val")["train"]
    outer = full_ds.train_test_split(test_size=0.1, seed=seed)
    inner = outer["train"].train_test_split(test_size=0.1, seed=seed)
    return inner["train"], inner["test"], outer["test"]

train_ds, val_ds, test_ds = load_data()

# Reuse the image processor's normalisation statistics and configured input
# size so the training tensors are prepared with the same mean/std/resolution.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
if "height" in processor.size:
    # Explicit (height, width) configuration.
    size = (processor.size["height"], processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in processor.size:
    # Only the shortest edge is configured: crop to a square of that size.
    size = processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = processor.size.get("longest_edge")
else:
    # Without this branch the cell would fall through and fail below with an
    # unhelpful NameError on `crop_size`.
    raise ValueError(
        f"Unsupported processor.size layout: {processor.size!r}; "
        "expected 'height'/'width' or 'shortest_edge'."
    )

# Training-time augmentation pipeline: random crop + horizontal flip, then
# tensor conversion and normalisation.
transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

def preprocess(example_batch):
    """Attach augmented training tensors to a batch of examples.

    Every entry of `example_batch["image"]` is converted to RGB and run
    through the module-level `transforms` pipeline (the random-crop/flip
    training pipeline); the results are stored as a list under
    `"pixel_values"`. The batch dict is modified in place and returned.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = [transforms(rgb) for rgb in rgb_images]
    return example_batch

def collate_fn(examples):
    """Merge a list of per-example dicts into a single batch dict.

    Args:
        examples: Sequence of dicts, each with a `"pixel_values"` tensor and
            a `"label"` value.

    Returns:
        Dict with `"pixel_values"` (the tensors stacked along a new leading
        dimension) and `"labels"` (a tensor built from the label values).
    """
    images, targets = [], []
    for sample in examples:
        images.append(sample["pixel_values"])
        targets.append(sample["label"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

# Register `preprocess` as the transform for the training split only; val_ds
# and test_ds are left without a transform in this cell.
# NOTE(review): per the `datasets` API this is presumably applied lazily when
# examples are accessed — confirm.
train_ds.set_transform(preprocess)


Repo card metadata block was not found. Setting CardData to empty.


# Training parameters

In [5]:
# Data-loading and optimisation settings for the training run.
batch_size = 256
num_workers = 8
n_steps = 10  # forwarded to `train_one_batch` for every batch

train_dataloader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

# Show the names of the parameters that have requires_grad=True.
params_to_optimize = [
    name for name, param in mask_gen_model.named_parameters() if param.requires_grad
]
print("params to be optimized: ", params_to_optimize, sep="\n")

params to be optimized: 
['similarity_measure.logit_scale', 'similarity_measure.query.weight', 'similarity_measure.query.bias', 'similarity_measure.key.weight', 'similarity_measure.key.bias']


In [6]:
import os

from tqdm import tqdm

# Checkpoints are written here; create the directory up front so the first
# torch.save (step 0 of the first epoch) cannot fail on a missing folder.
checkpoint_dir = 'mask_gen_model'
os.makedirs(checkpoint_dir, exist_ok=True)

num_epochs = 10

# Only parameters with requires_grad=True are handed to the optimizer.
params_to_optimize = [param for param in mask_gen_model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)


print()

for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        pixel_values = data['pixel_values'].to(device)
        # The optimizer is passed in, so the update itself happens inside
        # `train_one_batch`; only the returned metrics are used here.
        loss_dict = mask_gen_model.train_one_batch(pixel_values, optimizer=optimizer, n_steps=n_steps)
        # scheduler.step()
        pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss_dict['loss'].item():.4f}, "
                             f"Reward Loss = {loss_dict['reward_loss'].item():.4f}, "
                            #  f"Regret Loss = {loss_dict['regret_loss'].item():.4f}, "
                             f"Mask Loss = {loss_dict['mask_loss'].item():.4f} "
                             # Trailing space added: the fields used to run
                             # together as "...0.8601mask_mean = ...".
                             f"alt_mask_loss = {loss_dict['alt_mask_loss'].item():.4f} "
                             f"mask_mean = {loss_dict['mask_mean'].item():.4f} "
                             f"prob_mean = {loss_dict['prob_mean'].item():.4f} "
                             )
        # A newline every 10 steps leaves the current progress line in the
        # output as a persistent log entry.
        if idx % 10 == 0:
            print()
        # Periodic checkpoint every 100 steps (including step 0).
        if idx % 100 == 0:
            torch.save(mask_gen_model.state_dict(), f'{checkpoint_dir}/mask_gen_model_{epoch}_{idx}.pth')


# Final checkpoint, tagged with the last epoch/step indices reached.
torch.save(mask_gen_model.state_dict(), f'{checkpoint_dir}/mask_gen_model_final_{epoch}_{idx}.pth')





Epoch 1, Step 1: Loss = 0.4750, Reward Loss = 0.3881, Mask Loss = 0.0870 alt_mask_loss = 0.8601mask_mean = 0.4998 prob_mean = 0.5613 :   0%|          | 0/159 [00:06<?, ?it/s]




Epoch 1, Step 11: Loss = 0.1321, Reward Loss = 0.1205, Mask Loss = 0.0192 alt_mask_loss = 0.1144mask_mean = 0.1940 prob_mean = 0.2280 :   7%|▋         | 11/159 [00:58<11:37,  4.72s/it]




Epoch 1, Step 21: Loss = 0.1068, Reward Loss = 0.0986, Mask Loss = 0.0143 alt_mask_loss = 0.0798mask_mean = 0.1652 prob_mean = 0.1885 :  13%|█▎        | 21/159 [01:44<10:43,  4.67s/it]




Epoch 1, Step 31: Loss = 0.1307, Reward Loss = 0.1184, Mask Loss = 0.0218 alt_mask_loss = 0.1209mask_mean = 0.2041 prob_mean = 0.2250 :  19%|█▉        | 31/159 [02:31<09:55,  4.65s/it]




Epoch 1, Step 41: Loss = 0.1145, Reward Loss = 0.1059, Mask Loss = 0.0168 alt_mask_loss = 0.0839mask_mean = 0.1756 prob_mean = 0.2035 :  26%|██▌       | 41/159 [03:17<09:06,  4.63s/it]




Epoch 1, Step 51: Loss = 0.1267, Reward Loss = 0.1155, Mask Loss = 0.0203 alt_mask_loss = 0.1102mask_mean = 0.1904 prob_mean = 0.2238 :  32%|███▏      | 51/159 [04:04<08:19,  4.63s/it]




Epoch 1, Step 61: Loss = 0.1131, Reward Loss = 0.1042, Mask Loss = 0.0173 alt_mask_loss = 0.0866mask_mean = 0.1808 prob_mean = 0.2023 :  38%|███▊      | 61/159 [04:50<07:33,  4.63s/it]




Epoch 1, Step 71: Loss = 0.1324, Reward Loss = 0.1205, Mask Loss = 0.0205 alt_mask_loss = 0.1161mask_mean = 0.1833 prob_mean = 0.2340 :  45%|████▍     | 71/159 [05:36<06:46,  4.62s/it]




Epoch 1, Step 81: Loss = 0.1192, Reward Loss = 0.1086, Mask Loss = 0.0182 alt_mask_loss = 0.1038mask_mean = 0.1845 prob_mean = 0.2107 :  51%|█████     | 81/159 [06:22<06:00,  4.62s/it]




Epoch 1, Step 91: Loss = 0.1303, Reward Loss = 0.1182, Mask Loss = 0.0203 alt_mask_loss = 0.1194mask_mean = 0.1844 prob_mean = 0.2302 :  57%|█████▋    | 91/159 [07:09<05:14,  4.62s/it]




Epoch 1, Step 101: Loss = 0.1213, Reward Loss = 0.1101, Mask Loss = 0.0191 alt_mask_loss = 0.1092mask_mean = 0.1892 prob_mean = 0.2125 :  63%|██████▎   | 100/159 [07:55<04:32,  4.62s/it]




Epoch 1, Step 111: Loss = 0.1264, Reward Loss = 0.1161, Mask Loss = 0.0197 alt_mask_loss = 0.1007mask_mean = 0.1816 prob_mean = 0.2272 :  70%|██████▉   | 111/159 [08:47<03:45,  4.70s/it]




Epoch 1, Step 121: Loss = 0.1026, Reward Loss = 0.0947, Mask Loss = 0.0163 alt_mask_loss = 0.0775mask_mean = 0.1828 prob_mean = 0.1848 :  76%|███████▌  | 121/159 [09:33<02:56,  4.65s/it]




Epoch 1, Step 131: Loss = 0.1284, Reward Loss = 0.1169, Mask Loss = 0.0199 alt_mask_loss = 0.1126mask_mean = 0.1838 prob_mean = 0.2282 :  82%|████████▏ | 131/159 [10:20<02:10,  4.65s/it]




Epoch 1, Step 141: Loss = 0.1101, Reward Loss = 0.1017, Mask Loss = 0.0169 alt_mask_loss = 0.0818mask_mean = 0.1808 prob_mean = 0.1974 :  89%|████████▊ | 141/159 [11:07<01:23,  4.66s/it]




Epoch 1, Step 151: Loss = 0.1200, Reward Loss = 0.1099, Mask Loss = 0.0189 alt_mask_loss = 0.0987mask_mean = 0.1843 prob_mean = 0.2143 :  95%|█████████▍| 151/159 [11:53<00:37,  4.65s/it]




Epoch 1, Step 159: Loss = 0.1874, Reward Loss = 0.1669, Mask Loss = 0.0284 alt_mask_loss = 0.2023mask_mean = 0.1815 prob_mean = 0.3247 : 100%|██████████| 159/159 [12:27<00:00,  4.70s/it]
Epoch 2, Step 1: Loss = 0.1177, Reward Loss = 0.1083, Mask Loss = 0.0185 alt_mask_loss = 0.0924mask_mean = 0.1832 prob_mean = 0.2105 :   0%|          | 0/159 [00:06<?, ?it/s]




Epoch 2, Step 11: Loss = 0.1143, Reward Loss = 0.1045, Mask Loss = 0.0179 alt_mask_loss = 0.0968mask_mean = 0.1855 prob_mean = 0.2037 :   7%|▋         | 11/159 [00:58<11:35,  4.70s/it]




Epoch 2, Step 21: Loss = 0.1219, Reward Loss = 0.1111, Mask Loss = 0.0188 alt_mask_loss = 0.1056mask_mean = 0.1819 prob_mean = 0.2177 :  13%|█▎        | 21/159 [01:44<10:41,  4.65s/it]




Epoch 2, Step 31: Loss = 0.1142, Reward Loss = 0.1041, Mask Loss = 0.0175 alt_mask_loss = 0.0992mask_mean = 0.1851 prob_mean = 0.2021 :  19%|█▉        | 31/159 [02:30<09:54,  4.65s/it]




Epoch 2, Step 41: Loss = 0.1130, Reward Loss = 0.1036, Mask Loss = 0.0178 alt_mask_loss = 0.0922mask_mean = 0.1856 prob_mean = 0.2017 :  26%|██▌       | 41/159 [03:17<09:08,  4.65s/it]




Epoch 2, Step 51: Loss = 0.1289, Reward Loss = 0.1172, Mask Loss = 0.0197 alt_mask_loss = 0.1150mask_mean = 0.1815 prob_mean = 0.2281 :  32%|███▏      | 51/159 [04:03<08:21,  4.65s/it]




Epoch 2, Step 61: Loss = 0.1249, Reward Loss = 0.1140, Mask Loss = 0.0193 alt_mask_loss = 0.1069mask_mean = 0.1838 prob_mean = 0.2212 :  38%|███▊      | 61/159 [04:50<07:35,  4.65s/it]




Epoch 2, Step 71: Loss = 0.1300, Reward Loss = 0.1179, Mask Loss = 0.0200 alt_mask_loss = 0.1199mask_mean = 0.1821 prob_mean = 0.2289 :  45%|████▍     | 71/159 [05:36<06:48,  4.65s/it]




Epoch 2, Step 81: Loss = 0.1405, Reward Loss = 0.1270, Mask Loss = 0.0218 alt_mask_loss = 0.1322mask_mean = 0.1845 prob_mean = 0.2477 :  51%|█████     | 81/159 [06:23<06:02,  4.65s/it]




Epoch 2, Step 91: Loss = 0.1436, Reward Loss = 0.1304, Mask Loss = 0.0218 alt_mask_loss = 0.1303mask_mean = 0.1803 prob_mean = 0.2547 :  57%|█████▋    | 91/159 [07:09<05:16,  4.65s/it]




Epoch 2, Step 101: Loss = 0.0945, Reward Loss = 0.0878, Mask Loss = 0.0145 alt_mask_loss = 0.0664mask_mean = 0.1820 prob_mean = 0.1690 :  63%|██████▎   | 100/159 [07:56<04:34,  4.65s/it]




Epoch 2, Step 111: Loss = 0.1151, Reward Loss = 0.1054, Mask Loss = 0.0174 alt_mask_loss = 0.0952mask_mean = 0.1796 prob_mean = 0.2042 :  70%|██████▉   | 111/159 [08:48<03:44,  4.69s/it]




Epoch 2, Step 121: Loss = 0.1377, Reward Loss = 0.1256, Mask Loss = 0.0208 alt_mask_loss = 0.1190mask_mean = 0.1796 prob_mean = 0.2440 :  76%|███████▌  | 121/159 [09:34<02:56,  4.65s/it]




Epoch 2, Step 131: Loss = 0.1146, Reward Loss = 0.1048, Mask Loss = 0.0173 alt_mask_loss = 0.0962mask_mean = 0.1798 prob_mean = 0.2030 :  82%|████████▏ | 131/159 [10:21<02:10,  4.64s/it]




Epoch 2, Step 141: Loss = 0.1388, Reward Loss = 0.1259, Mask Loss = 0.0211 alt_mask_loss = 0.1266mask_mean = 0.1804 prob_mean = 0.2468 :  89%|████████▊ | 141/159 [11:07<01:23,  4.65s/it]




Epoch 2, Step 151: Loss = 0.1098, Reward Loss = 0.1009, Mask Loss = 0.0164 alt_mask_loss = 0.0875mask_mean = 0.1780 prob_mean = 0.1951 :  95%|█████████▍| 151/159 [11:53<00:37,  4.65s/it]




Epoch 2, Step 159: Loss = 0.1167, Reward Loss = 0.1068, Mask Loss = 0.0182 alt_mask_loss = 0.0970mask_mean = 0.1815 prob_mean = 0.2123 : 100%|██████████| 159/159 [12:27<00:00,  4.70s/it]
Epoch 3, Step 1: Loss = 0.1161, Reward Loss = 0.1067, Mask Loss = 0.0180 alt_mask_loss = 0.0919mask_mean = 0.1842 prob_mean = 0.2065 :   0%|          | 0/159 [00:05<?, ?it/s]




Epoch 3, Step 11: Loss = 0.1099, Reward Loss = 0.1010, Mask Loss = 0.0168 alt_mask_loss = 0.0876mask_mean = 0.1826 prob_mean = 0.1956 :   7%|▋         | 11/159 [00:57<11:35,  4.70s/it]




Epoch 3, Step 21: Loss = 0.1189, Reward Loss = 0.1089, Mask Loss = 0.0184 alt_mask_loss = 0.0987mask_mean = 0.1797 prob_mean = 0.2122 :  13%|█▎        | 21/159 [01:44<10:41,  4.65s/it]




Epoch 3, Step 31: Loss = 0.1190, Reward Loss = 0.1092, Mask Loss = 0.0178 alt_mask_loss = 0.0960mask_mean = 0.1814 prob_mean = 0.2119 :  19%|█▉        | 31/159 [02:30<09:54,  4.64s/it]




Epoch 3, Step 41: Loss = 0.1303, Reward Loss = 0.1185, Mask Loss = 0.0198 alt_mask_loss = 0.1157mask_mean = 0.1832 prob_mean = 0.2293 :  26%|██▌       | 41/159 [03:17<09:08,  4.65s/it]




Epoch 3, Step 51: Loss = 0.1373, Reward Loss = 0.1249, Mask Loss = 0.0209 alt_mask_loss = 0.1223mask_mean = 0.1808 prob_mean = 0.2426 :  32%|███▏      | 51/159 [04:03<08:22,  4.65s/it]




Epoch 3, Step 61: Loss = 0.1135, Reward Loss = 0.1042, Mask Loss = 0.0173 alt_mask_loss = 0.0915mask_mean = 0.1814 prob_mean = 0.2025 :  38%|███▊      | 61/159 [04:50<07:35,  4.65s/it]




Epoch 3, Step 71: Loss = 0.1108, Reward Loss = 0.1012, Mask Loss = 0.0176 alt_mask_loss = 0.0947mask_mean = 0.1846 prob_mean = 0.1976 :  45%|████▍     | 71/159 [05:36<06:49,  4.65s/it]




Epoch 3, Step 81: Loss = 0.1263, Reward Loss = 0.1146, Mask Loss = 0.0201 alt_mask_loss = 0.1158mask_mean = 0.1866 prob_mean = 0.2231 :  51%|█████     | 81/159 [06:23<06:03,  4.66s/it]




Epoch 3, Step 88: Loss = 0.1339, Reward Loss = 0.1217, Mask Loss = 0.0201 alt_mask_loss = 0.1197mask_mean = 0.1801 prob_mean = 0.2355 :  55%|█████▌    | 88/159 [06:59<06:37,  5.60s/it]

In [None]:
# NOTE(review): stray scratch cell holding a bare literal; nothing in the
# visible notebook reads it — confirm whether it can be removed.
3262

3262