# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

# Import our custom modules (assume they're in the same directory)
from notebook_1_data_prep_feature_extraction import MultiScaleFeatureExtractor
from notebook_2_hierarchical_transformer_implementation import HierarchicalTransformer

# Run-wide configuration shared by the data, model, training and evaluation code.
NUM_EPOCHS: int = 10          # number of passes over the training set
BATCH_SIZE: int = 32          # mini-batch size used by both DataLoaders
LEARNING_RATE: float = 0.001  # step size handed to the Adam optimizer
NUM_CLASSES: int = 10         # CIFAR-10 label count; size of the classifier output
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE: torch.device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Data Loading
def load_data(root='./data', batch_size=None, num_workers=2):
    """Build the CIFAR-10 training and test DataLoaders.

    Every image is resized to 224 (shorter side), converted to a tensor and
    normalized per channel. The mean/std values match the widely used
    ImageNet statistics — presumably chosen to suit MultiScaleFeatureExtractor;
    TODO confirm.

    Args:
        root: Directory holding the CIFAR-10 files; downloaded there if absent.
        batch_size: Samples per batch. ``None`` means the module-level
            ``BATCH_SIZE`` (resolved at call time).
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(trainloader, testloader)``. The training loader reshuffles every
        epoch; the test loader preserves dataset order.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    def _make_loader(train):
        # One builder for both splits; only the training split is shuffled.
        dataset = torchvision.datasets.CIFAR10(
            root=root, train=train, download=True, transform=transform
        )
        return DataLoader(dataset, batch_size=batch_size, shuffle=train,
                          num_workers=num_workers)

    trainloader = _make_loader(True)
    testloader = _make_loader(False)
    return trainloader, testloader

# Integrated Model
class IntegratedModel(nn.Module):
    """A MultiScaleFeatureExtractor followed by a HierarchicalTransformer head.

    The extractor's output is passed unchanged to the transformer, which
    produces the model output (``num_classes`` logits per sample — used with
    CrossEntropyLoss in ``__main__``).

    Args:
        num_scales: Number of feature scales the transformer consumes. This
            presumably must match what MultiScaleFeatureExtractor emits —
            TODO confirm.
        d_model: Transformer embedding width. This presumably must match the
            extractor's feature dimension — TODO confirm.
        nhead: Number of attention heads.
        num_classes: Output classes. ``None`` means the module-level
            ``NUM_CLASSES`` (resolved at construction time).
    """

    def __init__(self, num_scales=3, d_model=256, nhead=8, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.feature_extractor = MultiScaleFeatureExtractor()
        self.hierarchical_transformer = HierarchicalTransformer(
            num_scales=num_scales, d_model=d_model, nhead=nhead, num_classes=num_classes
        )

    def forward(self, x):
        """Extract multi-scale features from ``x`` and classify them."""
        features = self.feature_extractor(x)
        return self.hierarchical_transformer(features)

# Training function
def train(model, trainloader, criterion, optimizer, epoch, device=None, num_epochs=None):
    """Run one training epoch over ``trainloader``.

    Args:
        model: Module to train; switched to training mode here.
        trainloader: Iterable of ``(inputs, labels)`` batches with a ``len``.
        criterion: Loss function applied to ``(outputs, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters once per batch.
        epoch: Zero-based epoch index, used only in the progress-bar label.
        device: Where batches are moved. ``None`` means the module-level ``DEVICE``.
        num_epochs: Total epochs shown in the progress-bar label. ``None``
            means the module-level ``NUM_EPOCHS``.

    Returns:
        ``(mean_loss, accuracy_percent)`` for the epoch. Both are averaged
        per sample, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: If ``trainloader`` yields no samples.
    """
    device = DEVICE if device is None else device
    num_epochs = NUM_EPOCHS if num_epochs is None else num_epochs

    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far
    correct = 0
    total = 0

    pbar = tqdm(trainloader, total=len(trainloader), desc=f"Epoch {epoch+1}/{num_epochs}")
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Assumes the criterion averages over the batch (reduction='mean', the
        # CrossEntropyLoss default used in __main__). Re-weighting by batch size
        # makes the epoch loss a true per-sample mean, matching the accuracy.
        running_loss += loss.item() * batch_size
        total += batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()

        pbar.set_postfix({'Loss': running_loss / total, 'Acc': 100. * correct / total})

    if total == 0:
        raise ValueError("trainloader produced no samples; cannot compute epoch metrics")
    return running_loss / total, 100. * correct / total

# Evaluation function
def evaluate(model, testloader, device=None):
    """Compute classification accuracy (percent) of ``model`` on ``testloader``.

    The model is put in eval mode and gradients are disabled. The model is
    left in eval mode afterwards; ``train`` switches it back itself.

    Args:
        model: Module producing per-class scores along dimension 1.
        testloader: Iterable of ``(inputs, labels)`` batches.
        device: Where batches are moved. ``None`` means the module-level ``DEVICE``.

    Returns:
        Percentage of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    device = DEVICE if device is None else device

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(testloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("testloader produced no samples; cannot compute accuracy")
    return 100. * correct / total

# Main execution
if __name__ == "__main__":
    # Build the data pipeline, model, loss and optimizer from the module constants.
    trainloader, testloader = load_data()

    model = IntegratedModel().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    print(f"Training on {DEVICE}")
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, epoch)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")

    # Test-set accuracy is measured once, after the final epoch.
    test_acc = evaluate(model, testloader)
    print(f"Test Accuracy: {test_acc:.2f}%")

    # Kept inside the guard so that importing this module does not print a
    # completion message when no training or evaluation has happened.
    print("Training and evaluation complete.")