Liked our work? Give us a ⭐!
This repository contains an unofficial implementation of ViT (Vision Transformer), introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", using PyTorch. The implementation has been tested on the MNIST dataset for an image classification task.
- In order to use this code for images with multiple channels, change
self.cls_token = nn.Parameter(torch.randn(size=(1, in_channels, embed_dim)), requires_grad=True)
to
self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
We need two classes to implement ViT. The first is PatchEmbedding, which processes the image into embeddings to be fed to the transformer encoder. The second is ViT, which handles the rest of the process.
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> stack of pre-norm transformer encoder
    layers -> LayerNorm + Linear head applied to the [CLS] token.
    """

    def __init__(self, num_patches, img_size, num_classes, patch_size, embed_dim, num_encoders, num_heads, hidden_dim, dropout, activation, in_channels):
        super().__init__()
        # Image -> sequence of patch tokens (CLS token prepended inside).
        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)
        # One encoder layer, replicated num_encoders times by TransformerEncoder.
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.encoder_blocks = nn.TransformerEncoder(layer, num_layers=num_encoders)
        # Classification head; consumes only the CLS token's embedding.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch x."""
        tokens = self.embeddings_block(x)
        encoded = self.encoder_blocks(tokens)
        # Index 0 along the sequence dimension is the CLS token.
        return self.mlp_head(encoded[:, 0, :])
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches, linearly project them,
    prepend a learnable [CLS] token, and add learnable position embeddings.

    Args:
        embed_dim: dimensionality of each patch embedding.
        patch_size: side length of the square patches.
        num_patches: patches per image, i.e. (img_size // patch_size) ** 2.
        dropout: dropout probability applied after adding position embeddings.
        in_channels: number of input image channels.
    """

    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()
        # Patch extraction + linear projection in a single strided conv;
        # Flatten(2) collapses the spatial grid: output is (B, embed_dim, N).
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2),
        )
        # Bugfix: the CLS token is a single learned vector, so its shape must
        # be (1, 1, embed_dim) regardless of in_channels. The previous
        # (1, in_channels, embed_dim) only worked for 1-channel images (MNIST)
        # and broke the concat/position-embedding shapes for RGB input.
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        # One position embedding per token: num_patches patches + 1 CLS token.
        self.position_embeddings = nn.Parameter(torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map images (B, C, H, W) to token sequences (B, num_patches + 1, embed_dim)."""
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # one CLS token per batch item
        x = self.patcher(x).permute(0, 2, 1)  # (B, embed_dim, N) -> (B, N, embed_dim)
        x = torch.cat([cls_token, x], dim=1)  # prepend CLS token along the sequence dim
        x = self.position_embeddings + x
        x = self.dropout(x)
        return x
# Cross-entropy over raw logits; targets must be integer class indices.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY)

start = timeit.default_timer()
for epoch in tqdm(range(EPOCHS), position=0, leave=True):
    # ---- training pass ----
    model.train()
    train_labels = []
    train_preds = []
    train_running_loss = 0
    for idx, img_label in enumerate(tqdm(train_dataloader, position=0, leave=True)):
        img = img_label["image"].float().to(device)
        # Bugfix: nn.CrossEntropyLoss requires int64 (Long) class targets;
        # uint8 targets raise a dtype error in PyTorch.
        label = img_label["label"].type(torch.long).to(device)

        y_pred = model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        # Accumulated on CPU for the end-of-epoch accuracy report.
        train_labels.extend(label.cpu().detach())
        train_preds.extend(y_pred_label.cpu().detach())

        loss = criterion(y_pred, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()
    # Mean loss over the batches of this epoch (idx is the last batch index).
    train_loss = train_running_loss / (idx + 1)

    # ---- validation pass (no gradients) ----
    model.eval()
    val_labels = []
    val_preds = []
    val_running_loss = 0
    with torch.no_grad():
        for idx, img_label in enumerate(tqdm(val_dataloader, position=0, leave=True)):
            img = img_label["image"].float().to(device)
            label = img_label["label"].type(torch.long).to(device)

            y_pred = model(img)
            y_pred_label = torch.argmax(y_pred, dim=1)

            val_labels.extend(label.cpu().detach())
            val_preds.extend(y_pred_label.cpu().detach())

            loss = criterion(y_pred, label)
            val_running_loss += loss.item()
    val_loss = val_running_loss / (idx + 1)

    # ---- per-epoch report ----
    print("-"*30)
    print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
    print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
    print(f"Train Accuracy EPOCH {epoch+1}: {sum(1 for x,y in zip(train_preds, train_labels) if x == y) / len(train_labels):.4f}")
    print(f"Valid Accuracy EPOCH {epoch+1}: {sum(1 for x,y in zip(val_preds, val_labels) if x == y) / len(val_labels):.4f}")
    print("-"*30)

stop = timeit.default_timer()
print(f"Training Time: {stop-start:.2f}s")
# Show a 2x3 grid of sample images with their predicted labels.
# Bugfix: plt.subplots creates its own figure, so the original standalone
# plt.figure() call opened an extra, empty figure window — removed.
f, axarr = plt.subplots(2, 3)
counter = 0
for i in range(2):
    for j in range(3):
        # squeeze() drops the singleton channel dim so imshow gets (H, W).
        axarr[i][j].imshow(imgs[counter].squeeze(), cmap="gray")
        axarr[i][j].set_title(f"Predicted {labels[counter]}")
        counter += 1
You can run the code by downloading the notebook and updating the variables train_df and test_df to point to a valid dataset location.
You can contact me with this email address: uygarsci@gmail.com