# In [None]:
from src.dataset import load_dog_breed_dataset
from src.preprocessing import get_feature_extractor, preprocess_data
from src.model import initialize_model
from src.training import train_one_epoch, validate_model
from torch.utils.data import DataLoader
from configs.config import BATCH_SIZE, LEARNING_RATE, EPOCHS
import torch

# Load dataset and feature extractor (the dataset is fetched first, then the extractor).
dataset, feature_extractor = load_dog_breed_dataset(), get_feature_extractor()
# Run the loaded splits through preprocess_data with the extractor; `dataset`
# is rebound to the preprocessed result used by the rest of the script.
dataset = preprocess_data(dataset, feature_extractor)

# Create DataLoaders
def collate_fn(batch):
    """Merge a list of preprocessed examples into one model-ready batch.

    Args:
        batch: sequence of examples, each indexable by ``"pixel_values"``
            (tensors that must all share one shape, as ``torch.stack``
            requires) and ``"label"`` (presumably an integer class index,
            given the CrossEntropyLoss used below; TODO confirm).

    Returns:
        dict with ``"pixel_values"``: the example tensors stacked along a new
        leading batch dimension, and ``"labels"``: a tensor built from the
        examples' ``"label"`` values, in batch order.
    """
    stacked_images = torch.stack([sample["pixel_values"] for sample in batch])
    label_tensor = torch.tensor([sample["label"] for sample in batch])
    return dict(pixel_values=stacked_images, labels=label_tensor)

train_dataloader = DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(dataset['validation'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Initialize model with one output per class name in the training split.
num_labels = len(dataset['train'].features['label'].names)
model = initialize_model(num_labels)

# Choose the device and move the model onto it BEFORE constructing the
# optimizer (torch.optim recommends moving a model to GPU before building its
# optimizer). The training helpers receive `device` and presumably move each
# batch there, so the weights must live on the same device. `.to()` is a no-op
# when the model is already on `device`, so this is safe even if the helpers
# also move the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop: one training pass per epoch, then report the value returned
# by validate_model on the validation split.
for epoch in range(EPOCHS):
    train_one_epoch(model, train_dataloader, optimizer, loss_fn, device)
    accuracy = validate_model(model, val_dataloader, device)
    print(f"Epoch {epoch + 1}/{EPOCHS}, Validation Accuracy: {accuracy}")

# Save the model.
# NOTE(review): `feature_extractor` is not saved alongside the model; if
# inference needs the same preprocessing, consider saving it to this directory
# too (verify the object exposes save_pretrained).
model.save_pretrained("./models/dog_breed_vit")
