In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt # Keep for potential debugging/local plotting
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import transforms, models # Import models
from torchvision.datasets import ImageFolder # Use standard ImageFolder
from PIL import Image
from collections import deque
import random
import time
import wandb
import logging
from tqdm.notebook import tqdm # Use notebook version for Kaggle UI
import argparse # Keep for structure, but values often fixed/from wandb
import copy # Needed to save best model state
from wandb.sdk.wandb_settings import Settings # For timeout setting

# =============================================================================
# W&B Login (Using Kaggle Secrets or Environment Variable)
# =============================================================================
def _login_without_kaggle_secret(message):
    """ Prints why the Kaggle-secret path was not used, then calls wandb.login() with no key. """
    print(message)
    wandb.login()

try:
    # Preferred path: read the API key from the Kaggle secret store.
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
    wandb.login(key=wandb_api_key)
    print("W&B login successful using Kaggle Secret.")
except ImportError:
    # kaggle_secrets is not importable in this environment.
    _login_without_kaggle_secret("kaggle_secrets not found. Ensure it's enabled or use interactive/env var login.")
except Exception as e:
    # Any other failure while fetching the secret or logging in with it.
    _login_without_kaggle_secret(f"W&B login using Kaggle Secret failed: {e}. Trying other methods.")

In [None]:
 =============================================================================
# Configuration (Define fixed settings and default HPARAMS here)
# =============================================================================
# --- Fixed Settings ---
# <<< VERIFY THIS PATH to your dataset input in Kaggle >>>
DATA_DIR = "/kaggle/input/inaturalist-12k/inaturalist_12K"
SEED = 42
IMG_SIZE = 224 # ResNet usually uses 224
NUM_WORKERS = 2 # Kaggle typical limit
VAL_SPLIT = 0.2
WANDB_PROJECT_NAME = "CNN-iNaturalist-Scratch" # <<< USE SAME PROJECT AS PART A
WANDB_ENTITY = None # Optional: Let wandb infer or set "user_or_team_name"
OUTPUT_DIR = "/kaggle/working/output_partB" # Save outputs here in Kaggle
MODEL_SAVE_NAME = "resnet50_finetuned_best.pth"

# --- Default Hyperparameters (will be overridden by W&B sweep config) ---
DEFAULT_HPARAMS = {
    'finetune_strategy': 'feature_extract',
    'unfreeze_layers': 1, # K for unfreeze_last_k
    'learning_rate': 0.001,
    'weight_decay': 0,
    'batch_size': 32,
    'epochs': 15,
    'optimizer': 'adam', # Default optimizer
    'seed': SEED # Include seed in config logged
}

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', stream=sys.stdout)

# =============================================================================
# Utility Functions and Classes
# =============================================================================

# --- From utils_b.py ---
def set_seed(seed=42):
    """ Seeds the Python, NumPy and PyTorch RNGs (plus CUDA/cuDNN settings when a GPU is present). """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Request deterministic cuDNN kernels and disable the autotuner.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    logging.debug(f"Seed set to {seed}")

def seed_worker(worker_id):
    """ Re-seeds NumPy and Python `random` from torch's current initial seed (used as DataLoader worker_init_fn). """
    # Reduce to 32 bits: np.random.seed rejects larger values.
    worker_seed = torch.initial_seed() % (2 ** 32)
    for reseed in (np.random.seed, random.seed):
        reseed(worker_seed)
    logging.debug(f"Worker {worker_id} seeded with {worker_seed}")

# Module-level torch.Generator shared across trials: train_finetune_trial() re-seeds it
# with the trial seed (manual_seed) and passes it to get_data_loaders_b() as `generator`.
g_dataloader_seed_b = torch.Generator()

In [None]:
# --- Helper Class for Applying Transforms to Subsets ---
class TransformedSubset(Dataset):
    """ Dataset wrapper that applies `transform` to the sample of each (sample, label) pair of `subset`. """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, label = self.subset[index]
        # A falsy transform (e.g. None) leaves the sample untouched.
        return (self.transform(sample) if self.transform else sample), label

# --- From dataset_b.py ---
def _stratified_split_indices(targets, num_classes, val_split, seed):
    """ Splits sample indices into (train_indices, val_indices), stratified by class label. """
    local_rng = np.random.RandomState(seed)
    all_indices = list(range(len(targets)))

    # Bucket every sample index under its class label.
    indices_by_class = {lbl: [] for lbl in range(num_classes)}
    for idx in all_indices:
        indices_by_class[int(targets[idx])].append(idx)

    train_indices, val_indices = [], []
    for indices in indices_by_class.values():
        n_cls = len(indices)
        if n_cls == 0:
            continue
        local_rng.shuffle(indices)
        n_val = int(np.floor(val_split * n_cls))
        if n_cls > 1 and n_val == n_cls:
            n_val = n_cls - 1  # always leave at least one training sample per class
        elif n_cls <= 1 and val_split > 0:
            n_val = 0  # a singleton class stays entirely in train
        val_indices.extend(indices[:n_val])
        train_indices.extend(indices[n_val:])

    # If stratification left either side empty, fall back to a plain random split.
    if not train_indices or not val_indices:
        logging.warning("Split fail. Random split.")
        local_rng.shuffle(all_indices)
        split_point = int(len(all_indices) * (1 - val_split))
        train_indices = all_indices[:split_point]
        val_indices = all_indices[split_point:]

    logging.info(f"Split (seed {seed}): {len(train_indices)} train, {len(val_indices)} val.")
    local_rng.shuffle(train_indices)
    return train_indices, val_indices


def get_data_loaders_b(data_dir, batch_size=32, val_split=0.2, num_workers=2, img_size=224, seed=42, generator=None, augment=True):
    """ Creates DataLoaders for fine-tuning ResNet50.

    Args:
        data_dir: Directory containing 'train' and 'test' sub-folders in ImageFolder layout.
        batch_size: Batch size used by all three loaders.
        val_split: Fraction of the train folder held out for validation; the split is
            skipped (val_loader is None) unless 0 < val_split < 1.
        num_workers: Number of DataLoader worker processes.
        img_size: Side length of the crop fed to the model.
        seed: Seed for the train/val split and for the default generator.
        generator: Optional torch.Generator handed to the loaders; created from `seed` if None.
        augment: When True (default) the training set uses random crop / flip / colour jitter;
            when False it uses the same deterministic transform as validation and test.

    Returns:
        (train_loader, val_loader, test_loader, class_names); val_loader may be None.

    Raises:
        FileNotFoundError: If the 'train' or 'test' directory is missing.
    """
    logging.info(f"DataLoaders Part B: batch={batch_size}, val_split={val_split}, workers={num_workers}")

    # Fail fast on a bad dataset path before building transforms or loaders.
    train_dir, test_dir = os.path.join(data_dir, 'train'), os.path.join(data_dir, 'test')
    if not os.path.isdir(train_dir):
        raise FileNotFoundError(f"Train dir missing: {train_dir}")
    if not os.path.isdir(test_dir):
        raise FileNotFoundError(f"Test dir missing: {test_dir}")

    if generator is None:
        generator = torch.Generator().manual_seed(seed)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    val_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = val_test_transform

    try:
        # No transform here: images stay PIL so each split can apply its own transform.
        full_train_dataset = ImageFolder(root=train_dir)
        test_dataset_pil = ImageFolder(root=test_dir)
    except Exception as e:
        logging.error(f"Dataset load error: {e}")
        raise

    targets = np.array([s[1] for s in full_train_dataset.samples])
    dataset_size = len(targets)
    num_classes = len(full_train_dataset.classes)

    # Settings shared by all three loaders.
    use_workers = num_workers > 0
    common_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
        'worker_init_fn': seed_worker if use_workers else None,
        'generator': generator if use_workers else None,
    }
    # Persistent workers are applied to the train and test loaders only.
    use_persistent_workers = use_workers and sys.version_info >= (3, 8)
    loader_kwargs = {'persistent_workers': True, 'prefetch_factor': 2} if use_persistent_workers else {}

    val_loader = None
    if 0 < val_split < 1 and dataset_size >= 2 and num_classes > 0:
        train_indices, val_indices = _stratified_split_indices(targets, num_classes, val_split, seed)
        val_dataset_transformed = TransformedSubset(Subset(full_train_dataset, val_indices), transform=val_test_transform)
        val_loader = DataLoader(val_dataset_transformed, shuffle=False, **common_kwargs)
    else:
        logging.warning("Val split skipped.")
        train_indices = list(range(dataset_size))

    train_dataset_transformed = TransformedSubset(Subset(full_train_dataset, train_indices), transform=train_transform)
    test_dataset_transformed = TransformedSubset(test_dataset_pil, transform=val_test_transform)

    train_loader = DataLoader(train_dataset_transformed, shuffle=True, drop_last=True, **common_kwargs, **loader_kwargs)
    test_loader = DataLoader(test_dataset_transformed, shuffle=False, **common_kwargs, **loader_kwargs)
    return train_loader, val_loader, test_loader, full_train_dataset.classes

In [None]:
def set_parameter_requires_grad(model, strategy, num_unfreeze_layers=1):
    """ Sets requires_grad on the model's parameters according to the fine-tuning strategy.

    Args:
        model: nn.Module to configure; the classifier is looked up as an nn.Linear named 'fc'
            and residual stages as direct children whose name starts with 'layer'.
        strategy: 'feature_extract' (train only 'fc'), 'finetune_all' (train everything) or
            'unfreeze_last_k' (train 'fc' plus the last K 'layer*' children).
        num_unfreeze_layers: K for 'unfreeze_last_k'; clamped to [0, number of 'layer*' children].
            K == 0 unfreezes the classifier only.

    Returns:
        (trainable_params, total_params) as parameter counts.

    Raises:
        ValueError: If `strategy` is not one of the three supported names.
    """
    total_params = sum(p.numel() for p in model.parameters())

    def _freeze_all():
        for param in model.parameters():
            param.requires_grad = False

    def _unfreeze_fc():
        """ Unfreezes model.fc if it is an nn.Linear; returns its parameter count, else None. """
        if not (hasattr(model, 'fc') and isinstance(model.fc, nn.Linear)):
            return None
        for param in model.fc.parameters():
            param.requires_grad = True
        fc_count = sum(p.numel() for p in model.fc.parameters())
        logging.info(f" -> Unfrozen 'fc' ({fc_count:,} params).")
        return fc_count

    if strategy == 'feature_extract':
        logging.info("Strategy: Feature Extraction")
        _freeze_all()
        fc_params = _unfreeze_fc()
        if fc_params is None:
            logging.warning("Could not find 'fc' layer.")
            fc_params = 0
        trainable_params = fc_params
    elif strategy == 'finetune_all':
        logging.info("Strategy: Fine-tuning All Layers")
        for param in model.parameters():
            param.requires_grad = True
        trainable_params = total_params
    elif strategy == 'unfreeze_last_k':
        layer_names = [name for name, _ in model.named_children() if name.startswith('layer')]
        k = max(0, min(num_unfreeze_layers, len(layer_names)))
        logging.info(f"Strategy: Unfreezing last {k} block(s) + classifier")
        _freeze_all()
        fc_params = _unfreeze_fc() or 0
        # layer_names[-0:] would be the WHOLE list, so k == 0 is handled explicitly.
        layers_to_unfreeze = layer_names[-k:] if k > 0 else []
        unfrozen_block_params = 0
        if layers_to_unfreeze:
            logging.info(f" -> Unfreezing blocks: {layers_to_unfreeze}")
            for name, child in model.named_children():
                if name not in layers_to_unfreeze:
                    continue
                for param in child.parameters():
                    param.requires_grad = True
                block_params = sum(p.numel() for p in child.parameters())
                unfrozen_block_params += block_params
                logging.info(f"    - Unfrozen '{name}' ({block_params:,} params).")
        trainable_params = fc_params + unfrozen_block_params
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    logging.info(f"Model Parameters: Trainable={trainable_params:,} / Total={total_params:,}")
    return trainable_params, total_params

In [None]:
# =============================================================================
# Main Training Function (for W&B Agent)
# =============================================================================
def train_finetune_trial():
    """ Trains one fine-tuning trial based on wandb.config.

    Steps: start a W&B run, merge the sweep config over DEFAULT_HPARAMS, build the data
    loaders, load an ImageNet-pretrained ResNet50 with a new classifier head, apply the
    requested freezing strategy, train with Adam + ReduceLROnPlateau (stepped on validation
    accuracy), log per-epoch metrics, and finally save/log the best-validation checkpoint.

    Takes no arguments and returns None so it can be passed to wandb.agent. Errors raised
    during the trial are logged and the run is finished with exit_code=1; they are not
    re-raised.
    """
    run = None
    best_model_state = None
    best_val_acc = 0.0
    best_epoch = 0
    # Pre-initialise metrics so the summary section is well-defined even if zero epochs run.
    epoch_val_acc = float('nan')
    epoch_val_loss = float('nan')
    epoch_train_acc = 0.0
    epoch_train_loss = float('nan')
    epochs_run_actual = 0

    try:
        # --- Initialize W&B Run ---
        run = wandb.init(settings=Settings(init_timeout=300))
        # Use defaults for any value the sweep does not provide.
        hparams = {**DEFAULT_HPARAMS, **dict(wandb.config)}

        # --- Set Run Name ---
        try:
            strategy_tag = hparams['finetune_strategy'][:4]
            if hparams['finetune_strategy'] == 'unfreeze_last_k':
                strategy_tag += f"k{hparams['unfreeze_layers']}"
            run_name = (f"{strategy_tag}_bs{hparams['batch_size']}"
                        f"_lr{hparams['learning_rate']:.1E}_wd{hparams['weight_decay']:.1E}")
            run_name = run_name.replace('.', 'p').replace('-', '').replace('E', 'e')
            wandb.run.name = run_name[:128]
        except Exception as name_e:
            logging.warning(f"Could not set dynamic run name: {name_e}")
            wandb.run.name = f"finetune-{run.id[:8]}" # Fallback name

        logging.info(f"--- Starting Trial: {wandb.run.name} (ID: {run.id}) ---")
        logging.info(f"Effective Config: {hparams}") # Log the actual merged config used

        # --- Setup ---
        trial_seed = hparams['seed']
        set_seed(trial_seed)
        g_dataloader_seed_b.manual_seed(trial_seed)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info(f"Using device: {device}")

        # --- Data ---
        # Only arguments that get_data_loaders_b declares are passed here.
        train_loader, val_loader, test_loader, classes = get_data_loaders_b(
            data_dir=DATA_DIR, batch_size=hparams['batch_size'],
            num_workers=NUM_WORKERS, img_size=IMG_SIZE, val_split=VAL_SPLIT, seed=trial_seed,
            generator=g_dataloader_seed_b
        )
        num_classes = len(classes)

        # --- Model ---
        logging.info("Loading pre-trained ResNet50 model...")
        weights = models.ResNet50_Weights.IMAGENET1K_V1
        model = models.resnet50(weights=weights)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes) # Replace head
        logging.info(f"Replaced final layer for {num_classes} classes.")

        # --- Apply Fine-tuning Strategy ---
        trainable_params, total_params = set_parameter_requires_grad(
            model, hparams['finetune_strategy'], hparams['unfreeze_layers']
        )
        model = model.to(device)

        # --- Log Model Info ---
        wandb.watch(model, log='gradients', log_freq=100)
        wandb.config.update({
            "trainable_parameters": trainable_params,
            "total_parameters": total_params,
            "num_classes": num_classes
        }, allow_val_change=True)
        logging.info(f"Trainable Params: {trainable_params:,}, Total Params: {total_params:,}")

        # --- Loss, Optimizer, Scheduler ---
        # Only Adam is implemented; make it visible when the config asks for something else.
        requested_optimizer = str(hparams.get('optimizer', 'adam')).lower()
        if requested_optimizer != 'adam':
            logging.warning(f"Optimizer '{requested_optimizer}' is not implemented; using Adam.")
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=hparams['learning_rate'],
                               weight_decay=hparams['weight_decay'])
        # `verbose` is omitted: it is deprecated in recent PyTorch and False was the default.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3)

        # --- Training Loop ---
        epochs_to_run = hparams['epochs']
        logging.info(f"Starting training for {epochs_to_run} epochs...")

        for epoch in range(epochs_to_run):
            epochs_run_actual = epoch + 1
            epoch_start_time = time.time()

            # --- Training Epoch ---
            model.train()
            running_loss = 0.0; correct_train = 0; total_train = 0
            train_pbar = tqdm(train_loader, desc=f"Ep {epoch+1} Tr", leave=False, file=sys.stdout)
            for i, batch_data in enumerate(train_pbar):
                try:
                    inputs, labels = batch_data
                    assert not torch.any(labels < 0)
                except Exception:
                    # Malformed batch: skip it, but let KeyboardInterrupt/SystemExit propagate.
                    logging.warning(f"Skip bad train batch {i}")
                    continue
                inputs, labels = inputs.to(device), labels.to(device)
                try:
                    optimizer.zero_grad(set_to_none=True)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item() * inputs.size(0) # Accumulate total loss
                    _, predicted = outputs.max(1)
                    total_train += labels.size(0)
                    correct_train += predicted.eq(labels).sum().item()
                    if i % 20 == 0 and total_train > 0:
                        train_pbar.set_postfix({'L': f'{running_loss/total_train:.3f}', 'Acc': f'{100.*correct_train/total_train:.1f}%'})
                except RuntimeError as e:
                    logging.error(f"Runtime error: {e}")
                    torch.cuda.empty_cache()
                    raise
            train_pbar.close()
            epoch_train_loss = running_loss / total_train if total_train > 0 else 0
            epoch_train_acc = 100. * correct_train / total_train if total_train > 0 else 0

            # --- Validation Epoch ---
            epoch_val_loss = float('nan'); epoch_val_acc = float('nan') # Reset for epoch
            if val_loader and len(val_loader) > 0:
                model.eval()
                val_correct = 0; val_total = 0; val_loss_accum = 0.0
                val_pbar = tqdm(val_loader, desc=f"Ep {epoch+1} Val", leave=False, file=sys.stdout)
                with torch.no_grad():
                    for i_val, batch_data in enumerate(val_pbar):
                        try:
                            inputs, labels = batch_data
                            assert not torch.any(labels < 0)
                        except Exception:
                            logging.warning(f"Skip bad val batch {i_val}")
                            continue
                        inputs, labels = inputs.to(device), labels.to(device)
                        try:
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)
                            val_loss_accum += loss.item() * inputs.size(0) # Accumulate total loss
                            _, predicted = outputs.max(1)
                            val_total += labels.size(0)
                            val_correct += predicted.eq(labels).sum().item()
                            if i_val % 10 == 0 and val_total > 0:
                                val_pbar.set_postfix({'L': f'{val_loss_accum/val_total:.3f}', 'Acc': f'{100.*val_correct/val_total:.1f}%'})
                        except RuntimeError as e:
                            # Best effort: a failed validation batch is logged and skipped.
                            logging.error(f"Val forward error: {e}")
                val_pbar.close()
                if val_total > 0:
                    epoch_val_loss = val_loss_accum / val_total # Average loss per sample
                    epoch_val_acc = 100. * val_correct / val_total
                    scheduler.step(epoch_val_acc) # Step based on validation accuracy

            # --- Logging ---
            epoch_duration = time.time() - epoch_start_time
            current_lr = optimizer.param_groups[0]['lr']
            has_val = not np.isnan(epoch_val_acc)
            log_dict = {'epoch': epoch+1, 'train_loss': epoch_train_loss, 'train_accuracy': epoch_train_acc, 'lr': current_lr, 'epoch_sec': epoch_duration}
            if not np.isnan(epoch_val_loss):
                log_dict['val_loss'] = epoch_val_loss
            if has_val:
                log_dict['val_accuracy'] = epoch_val_acc
            wandb.log(log_dict, commit=True) # Log at end of epoch
            val_part = f'Val L:{epoch_val_loss:.3f},Val Acc:{epoch_val_acc:.2f}%|' if has_val else 'Val:N/A|'
            logging.info(f'E{epoch+1}/{epochs_to_run}|Tr L:{epoch_train_loss:.3f},Tr Acc:{epoch_train_acc:.2f}%|' + val_part + f'LR:{current_lr:.1E}|T:{epoch_duration:.1f}s')

            # --- Save Best Model State IN MEMORY ---
            if has_val and epoch_val_acc > best_val_acc:
                best_val_acc = epoch_val_acc; best_epoch = epoch + 1
                best_model_state = copy.deepcopy(model.state_dict())
                logging.info(f"*** Best val acc: {best_val_acc:.2f}% at Ep {best_epoch} ***")

            if torch.cuda.is_available():
                torch.cuda.empty_cache() # Clear cache end of epoch

        # --- End of Trial Summary Logging ---
        # With no recorded best epoch, fall back to the last epoch's validation accuracy (or 0).
        final_best_val_acc = best_val_acc if best_epoch > 0 else (epoch_val_acc if not np.isnan(epoch_val_acc) else 0)
        final_best_epoch = best_epoch if best_epoch > 0 else epochs_run_actual
        wandb.run.summary["best_val_accuracy"] = final_best_val_acc
        wandb.run.summary["best_epoch"] = final_best_epoch
        wandb.run.summary["final_train_accuracy"] = epoch_train_acc
        wandb.run.summary["final_val_loss"] = epoch_val_loss
        wandb.run.summary["final_val_accuracy"] = epoch_val_acc
        wandb.run.summary["epochs_completed"] = epochs_run_actual

        logging.info(f"--- Trial {wandb.run.name} Finished. Best Val Acc: {final_best_val_acc:.2f}% ---")

        # --- Optional: Save best model locally and as artifact ---
        if best_model_state:
            save_path = os.path.join(OUTPUT_DIR, f"best_model_{run.id}.pth")
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            try:
                save_dict = {'epoch': best_epoch, 'model_state_dict': best_model_state, 'best_val_accuracy': best_val_acc, 'config': hparams}
                torch.save(save_dict, save_path)
                logging.info(f"Best model state saved locally to {save_path}")
                model_artifact = wandb.Artifact(f"best-model-{run.id}", type="model", description=f"Best ResNet50 fine-tuned model from run {run.name}", metadata=hparams)
                model_artifact.add_file(save_path)
                wandb.log_artifact(model_artifact)
                logging.info("Best model logged as W&B artifact.")
            except Exception as e:
                logging.error(f"Failed to save/log best model state: {e}")

    except Exception as e:
        logging.error(f"Error during training trial {run.id if run else 'unknown'}: {e}", exc_info=True)
        if run and wandb.run:
            try:
                wandb.log({"error": str(e)[:1024]}, commit=True)
            except Exception as log_e:
                logging.warning(f"Could not log error to W&B: {log_e}")
            wandb.finish(exit_code=1)

    finally:
        # Finish the run if it is still the active one; the error path above has already
        # finished it, in which case wandb.run is expected to be None here.
        if run is not None and wandb.run is not None and wandb.run.id == run.id:
            try:
                wandb.finish()
            except Exception as fe:
                logging.error(f"Error finishing W&B run: {fe}")

In [None]:
# =============================================================================
# Sweep Configuration Reference (for Part B sweep 'b49z6yly')
# =============================================================================
# Reference only: the agent below pulls the real configuration from the W&B server.
sweep_config_part_b_ref = {
    'method': 'bayes',
    'metric': {'name': 'best_val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'finetune_strategy': {'values': ['feature_extract', 'unfreeze_last_k', 'finetune_all']},
        'unfreeze_layers': {'values': [1, 2, 3]},
        'learning_rate': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 5e-3},
        'weight_decay': {'values': [0, 0.0001, 0.0005]},
        'batch_size': {'values': [16, 32]},
        'epochs': {'value': 15},
        'seed': {'value': 42}
        # --- Add other hyperparameters from your ACTUAL sweep config if different ---
        # e.g., 'optimizer': {'values': ['adam']}
        # e.g., 'hidden_size': {'values': [256]} # If Critic/Actor size was swept (unlikely needed)
    }
}
print("Sweep Configuration Reference (Agent uses actual config from W&B):")
# import json; print(json.dumps(sweep_config_part_b_ref, indent=2)) # Pretty print

# =============================================================================
# Start the W&B Agent for Part B Sweep
# =============================================================================
# Provide defaults so this cell survives Restart & Run All when no earlier cell has
# defined them. The fallback sweep id is the one named in the banner above and can be
# overridden with the WANDB_SWEEP_ID environment variable.
if "SWEEP_ID" not in globals():
    SWEEP_ID = os.environ.get("WANDB_SWEEP_ID", "b49z6yly")
if "AGENT_RUN_COUNT" not in globals():
    AGENT_RUN_COUNT = None # None -> no trial limit passed to the agent

print(f"\n--- Starting W&B Agent for Fine-tuning Sweep: {SWEEP_ID} ---")
print(f"--- Agent will run max {AGENT_RUN_COUNT if AGENT_RUN_COUNT else 'unlimited'} trials ---")

try:
    # Run the agent, calling train_finetune_trial for each run configuration
    wandb.agent(SWEEP_ID, function=train_finetune_trial, count=AGENT_RUN_COUNT, project=WANDB_PROJECT_NAME, entity=WANDB_ENTITY)
except KeyboardInterrupt:
    logging.warning("Sweep agent stopped manually (KeyboardInterrupt).")
except Exception as e:
    logging.error(f"W&B Agent execution stopped due to error: {e}", exc_info=True)

print("--- W&B Agent Finished ---")