<a href="https://colab.research.google.com/github/williamedwardhahn/MathData25/blob/main/WandB_Version_Train_Alexnet_Bug_and_Bats_Alexnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Loading and Training

In [2]:
# ============================================
# 1️⃣ Setup & Imports
# ============================================
!pip install wandb gdown -q

import os
import gdown
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import copy
import time
import warnings

# ============================================
# 2️⃣ Initialize W&B
# ============================================
# Hyperparameters recorded with the run. Later code reads batch_size,
# learning_rate and epochs back through `config`.
hyperparameters = {
    "architecture": "AlexNet",
    "dataset": "bats_vs_bugs",
    "epochs": 10,
    "batch_size": 32,
    "learning_rate": 1e-4,
}

# Log in to W&B, then start the run with the settings above.
wandb.login()
wandb.init(project="alexnet-training-demo", config=hyperparameters)
config = wandb.config

# ============================================
# 3️⃣ Environment Setup
# ============================================
# Suppress warnings whose message matches either of these regex patterns.
for noisy_message in (
    ".*can only test a child process.*",
    ".*Bad file descriptor.*",
):
    warnings.filterwarnings("ignore", message=noisy_message)

# Use CUDA when PyTorch reports it as available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"✅ Using device: {device}")

# ============================================
# 4️⃣ Download Dataset from Google Drive
# ============================================
drive_url = "https://drive.google.com/drive/folders/1QmgIbDjfaiWbZxia4uJ84P_VTh8ggs1L?usp=drive_link"
folder_id = drive_url.split('/')[-1].split('?')[0]
local_path = "/content/data"
os.makedirs(local_path, exist_ok=True)

print("📦 Downloading dataset from Google Drive...")
!gdown --folder "https://drive.google.com/drive/folders/{folder_id}" -O {local_path} --quiet
print("✅ Download complete!")

# ============================================
# 5️⃣ Detect Train/Valid Folders
# ============================================
# Immediate subdirectories of the download folder, sorted so the choice below
# does not depend on the arbitrary order returned by os.listdir().
subdirs = sorted(
    path
    for path in (os.path.join(local_path, entry) for entry in os.listdir(local_path))
    if os.path.isdir(path)
)

# Match against each directory's own name, not its full path; otherwise a
# parent directory whose name happens to contain "train" or "val" would make
# every subfolder match.
train_dir = next(
    (d for d in subdirs if "train" in os.path.basename(d).lower()),
    None,
)
# "val" is a substring of "valid"/"validation", so one test covers them all.
valid_dir = next(
    (d for d in subdirs if "val" in os.path.basename(d).lower()),
    None,
)

if not train_dir or not valid_dir:
    raise FileNotFoundError(f"Couldn't find 'train' or 'valid' folders. Found: {subdirs}")

print(f"✅ Found train: {train_dir}")
print(f"✅ Found valid: {valid_dir}")

# ============================================
# 6️⃣ Data Preparation
# ============================================
input_size = 224
batch_size = config.batch_size

# Arguments shared by the training and validation pipelines: the square
# target size for Resize and the per-channel mean/std given to Normalize.
target_shape = (input_size, input_size)
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: same steps as validation plus a random horizontal flip.
train_transforms = transforms.Compose([
    transforms.Resize(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
# Validation pipeline: no random augmentation.
valid_transforms = transforms.Compose([
    transforms.Resize(target_shape),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# ImageFolder takes class labels from the subdirectory names of each split.
train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
valid_dataset = datasets.ImageFolder(valid_dir, transform=valid_transforms)

# Both loaders share batch size and worker count; only training is shuffled.
loader_kwargs = {"batch_size": batch_size, "num_workers": 0}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

print(f"📂 Train size: {len(train_dataset)} | Valid size: {len(valid_dataset)}")
print(f"🧠 Classes: {train_dataset.classes}")

# ============================================
# 7️⃣ Define Model, Loss, Optimizer
# ============================================
print("⚙️ Loading AlexNet...")
num_classes = len(train_dataset.classes)
model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
# Replace the last classifier layer with a fresh Linear layer that keeps the
# same input width but outputs one score per dataset class.
pretrained_head = model.classifier[6]
model.classifier[6] = nn.Linear(pretrained_head.in_features, num_classes)
model = model.to(device)
print("✅ Model ready!")

# Training settings come from the W&B run config.
num_epochs = config.epochs
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

# Best-so-far tracking starts from the initial weights with accuracy 0.
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

wandb.watch(model, log="all", log_freq=10)

# ============================================
# 8️⃣ Training & Validation Loop
# ============================================
# One pass over the training set followed by one pass over the validation set
# per epoch. Metrics are logged to W&B, and the weights with the highest
# validation accuracy so far are kept in memory and written to best_model.pth.
for epoch in range(num_epochs):
    print(f"\n🔁 Epoch {epoch+1}/{num_epochs}")
    epoch_start = time.time()

    # Running totals are kept as plain Python numbers (via .item()), so the
    # derived metrics are ordinary floats: safe to log, compare and format,
    # and well-defined even if a loader yields no batches.
    train_loss, valid_loss = 0.0, 0.0
    train_corrects, valid_corrects = 0, 0

    # ----- TRAIN -----
    model.train()
    for inputs, labels in tqdm(train_loader, desc="Training", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest score along dim 1.
        _, preds = torch.max(outputs, 1)
        # criterion returns the batch mean (default reduction), so weight it
        # by the batch size to accumulate a per-sample sum.
        train_loss += loss.item() * inputs.size(0)
        train_corrects += (preds == labels).sum().item()

    epoch_train_loss = train_loss / len(train_dataset)
    epoch_train_acc = train_corrects / len(train_dataset)

    # ----- VALIDATION -----
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(valid_loader, desc="Validating", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            valid_loss += loss.item() * inputs.size(0)
            valid_corrects += (preds == labels).sum().item()

    epoch_valid_loss = valid_loss / len(valid_dataset)
    epoch_valid_acc = valid_corrects / len(valid_dataset)

    # ----- LOG TO W&B -----
    wandb.log({
        "epoch": epoch + 1,
        "train/loss": epoch_train_loss,
        "train/accuracy": epoch_train_acc,
        "valid/loss": epoch_valid_loss,
        "valid/accuracy": epoch_valid_acc,
        "learning_rate": optimizer.param_groups[0]["lr"]
    })

    # ----- PRINT SUMMARY -----
    print(f"📊 Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc:.4f}")
    print(f"📈 Valid Loss: {epoch_valid_loss:.4f} | Valid Acc: {epoch_valid_acc:.4f}")

    # ----- SAVE BEST MODEL -----
    # Strict ">" means ties keep the earlier checkpoint.
    if epoch_valid_acc > best_acc:
        best_acc = epoch_valid_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(model.state_dict(), "best_model.pth")
        print(f"💾 Saved new best model (Acc: {best_acc:.4f})")

    print(f"⏱️ Epoch Time: {time.time() - epoch_start:.2f}s")

print("\n🏁 Training complete!")
print(f"🏆 Best Validation Accuracy: {best_acc:.4f}")

# ============================================
# 9️⃣ Finalize Run
# ============================================
# Put the best-scoring weights back into the model before closing the run.
model.load_state_dict(best_model_wts)

# Send a W&B alert with the headline metric, save the checkpoint file to the
# run, then end the run.
completion_note = f"Best validation accuracy: {best_acc:.4f}"
wandb.alert(title="Training Complete", text=completion_note)
wandb.save("best_model.pth")
wandb.finish()

print("✅ Best model loaded and W&B run completed!")


✅ Using device: cpu
📦 Downloading dataset from Google Drive...
Failed to retrieve file url:

	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses.
	Check FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=17e9ISYrYLRRl6y1wS3ssA3uJE8O3JhUD

but Gdown can't. Please check connections and permissions.
✅ Download complete!
✅ Found train: /content/data/train
✅ Found valid: /content/data/valid
📂 Train size: 19 | Valid size: 20
🧠 Classes: ['bats', 'bugs']
⚙️ Loading AlexNet...
✅ Model ready!

🔁 Epoch 1/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.9910 | Train Acc: 0.4737
📈 Valid Loss: 0.6093 | Valid Acc: 0.7500
💾 Saved new best model (Acc: 0.7500)
⏱️ Epoch Time: 6.71s

🔁 Epoch 2/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.3012 | Train Acc: 0.8421
📈 Valid Loss: 0.4209 | Valid Acc: 0.8000
💾 Saved new best model (Acc: 0.8000)
⏱️ Epoch Time: 5.41s

🔁 Epoch 3/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.0719 | Train Acc: 1.0000
📈 Valid Loss: 0.3161 | Valid Acc: 0.8500
💾 Saved new best model (Acc: 0.8500)
⏱️ Epoch Time: 6.61s

🔁 Epoch 4/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.0279 | Train Acc: 1.0000
📈 Valid Loss: 0.2569 | Valid Acc: 0.8500
⏱️ Epoch Time: 4.92s

🔁 Epoch 5/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.0057 | Train Acc: 1.0000
📈 Valid Loss: 0.2269 | Valid Acc: 0.8500
⏱️ Epoch Time: 7.78s

🔁 Epoch 6/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.0128 | Train Acc: 1.0000
📈 Valid Loss: 0.2065 | Valid Acc: 0.9000
💾 Saved new best model (Acc: 0.9000)
⏱️ Epoch Time: 5.29s

🔁 Epoch 7/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.0014 | Train Acc: 1.0000
📈 Valid Loss: 0.1943 | Valid Acc: 0.9000
⏱️ Epoch Time: 5.11s

🔁 Epoch 8/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.0005 | Train Acc: 1.0000
📈 Valid Loss: 0.1863 | Valid Acc: 0.9000
⏱️ Epoch Time: 5.84s

🔁 Epoch 9/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.0001 | Train Acc: 1.0000
📈 Valid Loss: 0.1797 | Valid Acc: 0.9000
⏱️ Epoch Time: 4.81s

🔁 Epoch 10/10


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Validating:   0%|          | 0/1 [00:00<?, ?it/s]

📊 Train Loss: 0.0000 | Train Acc: 1.0000
📈 Valid Loss: 0.1753 | Valid Acc: 0.9500
💾 Saved new best model (Acc: 0.9500)
⏱️ Epoch Time: 9.94s

🏁 Training complete!
🏆 Best Validation Accuracy: 0.9500


0,1
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,▁▁▁▁▁▁▁▁▁▁
train/accuracy,▁▆████████
train/loss,█▃▂▁▁▁▁▁▁▁
valid/accuracy,▁▃▅▅▅▆▆▆▆█
valid/loss,█▅▃▂▂▂▁▁▁▁

0,1
epoch,10.0
learning_rate,0.0001
train/accuracy,1.0
train/loss,4e-05
valid/accuracy,0.95
valid/loss,0.1753


✅ Best model loaded and W&B run completed!
