In [1]:
from transformers import ViTMAEConfig, ViTMAEModel, ViTMAEForPreTraining
%load_ext autoreload


# Build a ViT-MAE configuration (vit-mae-base style) with the sizes overridden
# for single-channel 100x100 inputs cut into 10x10 patches.
configuration = ViTMAEConfig(
    # input geometry
    num_channels=1,
    image_size=100,
    patch_size=10,
    # transformer widths
    hidden_size=480,
    intermediate_size=1024,
    decoder_intermediate_size=1024,
    # masking ratio of 0 -- presumably disables patch masking; confirm against the library docs
    mask_ratio=0,
)

# Instantiate the pre-training model from that configuration (weights are random, nothing is loaded)
model = ViTMAEForPreTraining(configuration)

# Re-bind `configuration` to the config object held by the model
configuration = model.config

print(
    'number of parameters: ',
    sum(weight.numel() for weight in model.parameters()),
)


number of parameters:  40191300


In [2]:
from dataloader import BATCH_SIZE, square_xrd_dataloader, square_binary_dataloader

import torch
from torch import nn, optim
# Mean-squared-error criterion; train_model below uses it as the training loss.
mse_loss = nn.MSELoss()

def train_model(num_epochs=100):
    """Train the notebook-level ``model`` on ``square_binary_dataloader``.

    Runs ``num_epochs`` passes over the dataloader with Adam
    (lr=0.01, weight_decay=1e-5), minimising ``mse_loss`` between the model's
    ``logits`` and the input batch. Progress is printed every 5 batches and
    once at the end of every epoch.

    Args:
        num_epochs: Number of full passes over ``square_binary_dataloader``.

    Returns:
        A list with one ``(epoch, data, logits)`` tuple per epoch, holding the
        last batch of that epoch and the detached logits the model produced
        for it.

    Raises:
        ValueError: If the dataloader yields no batches in an epoch.
    """
    outputs = []
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
    for epoch in range(num_epochs):
        loss = None  # stays None only when the dataloader is empty
        for idx, data in enumerate(square_binary_dataloader):
            # ===================forward=====================
            output = model(data)
            # Drop only the channel axis (assumes data is (batch, 1, H, W) --
            # TODO confirm against the dataloader). A bare squeeze() would
            # also drop the batch axis whenever a batch holds a single sample.
            # NOTE(review): the logits are presumably in patch-major layout;
            # confirm whether the target should be model.patchify(data)
            # rather than the raw image rows.
            loss = mse_loss(output.logits, data.squeeze(1))
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if idx % 5 == 0:
                print(f"Finished batch {idx} in epoch {epoch + 1}. Loss: {loss.item():.4f}")

        if loss is None:
            raise ValueError("square_binary_dataloader yielded no batches; nothing to train on")

        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))
        # Detach so the stored snapshot does not keep the autograd graph alive.
        outputs.append((epoch, data, output.logits.detach()))

    return outputs



# Train the model

# Switch to training mode for the run and always return to evaluation mode
# afterwards, even when the run is interrupted (e.g. KeyboardInterrupt).
model.train()
try:
    train_model(num_epochs=1)
finally:
    model.eval()

Finished batch 0 in epoch 1. Loss: 1.1780
Finished batch 5 in epoch 1. Loss: 3.8377
Finished batch 10 in epoch 1. Loss: 1.4074
Finished batch 15 in epoch 1. Loss: 0.4227
Finished batch 20 in epoch 1. Loss: 0.3132
Finished batch 25 in epoch 1. Loss: 0.2035
Finished batch 30 in epoch 1. Loss: 0.1449
Finished batch 35 in epoch 1. Loss: 0.1011
Finished batch 40 in epoch 1. Loss: 0.0542
Finished batch 45 in epoch 1. Loss: 0.0345
Finished batch 50 in epoch 1. Loss: 0.0193
Finished batch 55 in epoch 1. Loss: 0.0139
Finished batch 60 in epoch 1. Loss: 0.0100
Finished batch 65 in epoch 1. Loss: 0.0078
Finished batch 70 in epoch 1. Loss: 0.0073
Finished batch 75 in epoch 1. Loss: 0.0073
Finished batch 80 in epoch 1. Loss: 0.0072
Finished batch 85 in epoch 1. Loss: 0.0067
Finished batch 90 in epoch 1. Loss: 0.0065
Finished batch 95 in epoch 1. Loss: 0.0063
Finished batch 100 in epoch 1. Loss: 0.0063
Finished batch 105 in epoch 1. Loss: 0.0063
Finished batch 110 in epoch 1. Loss: 0.0063
Finished b

KeyboardInterrupt: 

In [13]:
# this is a demo for testing

from transformers import AutoFeatureExtractor, ViTMAEForPreTraining
from PIL import Image
import requests

# Download a sample image; bound the wait and fail early on an HTTP error so a
# bad response never reaches the image decoder.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Load the published preprocessing pipeline and pretrained weights.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
model_pretrained = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

# Preprocess the image into PyTorch tensors and run one forward pass.
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model_pretrained(**inputs)
loss = outputs.loss
mask = outputs.mask
ids_restore = outputs.ids_restore

# Show the tensor type of the preprocessed pixel values.
inputs['pixel_values'].type()