In [4]:
import os
import sys

import torch

# Make the project root importable so the `src` and `config` packages resolve.
# Assumes the kernel's working directory sits one level below the project root
# (e.g. a notebooks/ folder) — TODO confirm against the repository layout.
# The path is normalised and added only when missing, so re-running this cell
# does not pile up duplicate '<cwd>/..' entries in sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src.vit import VisionTransformer
from config.CONFIG import (
    IMAGE_SIZE,
    PATCH_SIZE,
    NUM_CHANNELS,
    EMBED_DIM,
    NUM_CLASSES,
    DEPTH,
    ATTENTION_HEADS,
    HIDDEN_DIM,
)

print("Libraries imported successfully")

Libraries imported successfully


In [5]:
# Build the ViT model from the hyper-parameters imported from config.CONFIG.
# NOTE(review): HIDDEN_DIM is passed as `mlp_ratio`. If the constructor treats
# `mlp_ratio` as a multiplier of embed_dim rather than an absolute width, this
# may not be what was intended — confirm against src/vit.py.
vit_kwargs = dict(
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=NUM_CHANNELS,
    num_classes=NUM_CLASSES,
    embed_dim=EMBED_DIM,
    depth=DEPTH,
    num_heads=ATTENTION_HEADS,
    mlp_ratio=HIDDEN_DIM,
)
model = VisionTransformer(**vit_kwargs)

print(f"Model created with embed_dim={EMBED_DIM}, depth={DEPTH}, num_heads={ATTENTION_HEADS}")

# Create a random input batch laid out as (batch, channels, height, width).
# NOTE(review): no seed is set in the visible cells, so the sampled values
# differ between runs; the shapes do not.
batch_size = 2
input_shape = (batch_size, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
input_tensor = torch.randn(*input_shape)
print(f"Input tensor shape: {input_tensor.shape}")

Model created with embed_dim=64, depth=12, num_heads=8
Input tensor shape: torch.Size([2, 1, 28, 28])


In [None]:
# Debug the ViT forward pass one stage at a time, printing the shape after each
# stage and asserting it against the expected layout.
#
# The trace runs under torch.no_grad(): this cell only inspects shapes, so there
# is no need to record an autograd graph through every transformer block.
# NOTE(review): no_grad does not disable dropout. If `model` is still in training
# mode, `pos_drop` and any dropout inside the blocks presumably stay active, so
# the values (not the shapes) vary between runs — call model.eval() beforehand
# if deterministic values matter.
print("=== Debugging VIT Forward Pass ===")

with torch.no_grad():
    # 1. Patch Embedding
    x = model.patch_embed(input_tensor)
    print(f"After patch embedding: {x.shape}")  # Expected: [batch_size, num_patches, embed_dim]
    assert x.ndim == 3, f"patch embedding should be 3-D, got {tuple(x.shape)}"
    assert x.shape[0] == input_tensor.shape[0], "batch dimension changed in patch embedding"
    assert x.shape[2] == EMBED_DIM, f"expected embed_dim={EMBED_DIM}, got {x.shape[2]}"
    num_patches = x.shape[1]

    # 2. Add CLS token
    bs = x.shape[0]
    # Broadcast the learned CLS token across the batch dimension.
    cls_tokens = model.cls_token.expand(bs, -1, -1)
    print(f"CLS token shape: {cls_tokens.shape}")
    x = torch.cat((cls_tokens, x), dim=1)
    print(f"After adding CLS token: {x.shape}")  # Expected: [batch_size, 1 + num_patches, embed_dim]
    assert tuple(x.shape) == (bs, 1 + num_patches, EMBED_DIM), (
        f"unexpected shape after CLS concat: {tuple(x.shape)}"
    )
    tokens_shape = tuple(x.shape)

    # 3. Add positional embedding
    x = x + model.pos_embed
    print(f"After positional embedding: {x.shape}")
    # A mismatched pos_embed could silently broadcast to a different shape.
    assert tuple(x.shape) == tokens_shape, (
        f"positional embedding changed the shape to {tuple(x.shape)}"
    )
    x = model.pos_drop(x)

    # 4. Transformer blocks
    for i, block in enumerate(model.blocks):
        x = block(x)
        print(f"After transformer block {i+1}: {x.shape}")
        assert tuple(x.shape) == tokens_shape, (
            f"block {i+1} changed the shape to {tuple(x.shape)}"
        )

    # 5. Final layer norm
    x = model.norm(x)
    print(f"After final layer norm: {x.shape}")
    assert tuple(x.shape) == tokens_shape, (
        f"final norm changed the shape to {tuple(x.shape)}"
    )

    # 6. Extract CLS token
    # Index 0 along dim 1 is the CLS token prepended in step 2.
    cls_output = x[:, 0]
    print(f"CLS token output shape: {cls_output.shape}")  # Expected: [batch_size, embed_dim]
    assert tuple(cls_output.shape) == (bs, EMBED_DIM), (
        f"unexpected CLS output shape: {tuple(cls_output.shape)}"
    )

    # 7. Classifier head
    output = model.head(cls_output)
    print(f"Final output shape: {output.shape}")  # Expected: [batch_size, num_classes]
    assert tuple(output.shape) == (bs, NUM_CLASSES), (
        f"unexpected classifier output shape: {tuple(output.shape)}"
    )

print("=== Debugging Complete ===")