In [2]:
import timm
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import os
from torch import nn, optim
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler

# Set paths and parameters
model_name = "hf_hub:timm/vit_base_patch16_224.mae"
data_dir = "path_to_imagenet1k"
output_dir = "./vit_checkpoints"
batch_size = 128
num_steps = 100000
save_every = 1000

# Create the model
model = timm.create_model(model_name, pretrained=True, num_classes=1000)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Data transforms and loading
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Load the ImageNet-1k dataset
train_dataset = ImageFolder(root=os.path.join(data_dir, 'train'), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = create_optimizer_v2(model, 'adamw', lr=3e-5, weight_decay=0.01)
scheduler, _ = create_scheduler(100, optimizer, num_steps)

# Training loop
def train_model(model, train_loader, criterion, optimizer, scheduler, num_steps, save_every, output_dir):
    model.train()
    step = 0
    
    while step < num_steps:
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            if step % save_every == 0:
                save_path = os.path.join(output_dir, f"checkpoint_step_{step}.pth")
                torch.save(model.state_dict(), save_path)
                print(f"Checkpoint saved at step {step}: {save_path}")
            
            step += 1
            if step >= num_steps:
                break

    print("Training complete.")

# Make sure output directory exists
os.makedirs(output_dir, exist_ok=True)

# Start training
train_model(model, train_loader, criterion, optimizer, scheduler, num_steps, save_every, output_dir)

config.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

FileNotFoundError: [Errno 2] No such file or directory: 'path_to_imagenet1k/train'