# Chickpea Fusarium Disease Detection - Training
## SqueezeNet + CBAM Experiments

This notebook trains lightweight CNN models for Fusarium wilt detection in chickpea images. The run recorded here is the plain SqueezeNet 1.1 baseline (`squeezenet1_1`), without CBAM.

In [6]:
# Cell 1: Setup - Clone repo and add to path
import subprocess
import sys
import os
import shutil

# Clone the repository
REPO_PATH = '/kaggle/working/chickpea-fusarium'

# Delete old clone to get fresh code
if os.path.exists(REPO_PATH):
    print("Removing old clone...")
    shutil.rmtree(REPO_PATH)

print("Cloning repository...")
subprocess.run([
    "git", "clone", "--depth", "1",
    "https://github.com/tklwin/chickpea-fusarium.git",
    REPO_PATH
], check=True)
print(f"✓ Repository cloned to {REPO_PATH}")

# Add to Python path
if REPO_PATH not in sys.path:
    sys.path.insert(0, REPO_PATH)
    
print("✓ Path setup complete")
print(f"  sys.path[0]: {sys.path[0]}")

Removing old clone...
Cloning repository...


Cloning into '/kaggle/working/chickpea-fusarium'...


✓ Repository cloned to /kaggle/working/chickpea-fusarium
✓ Path setup complete
  sys.path[0]: /kaggle/working/chickpea-fusarium


In [2]:
# Cell 2: Install dependencies
!pip install -q albumentations wandb

print("✓ Dependencies installed")

✓ Dependencies installed


In [7]:
# Cell 3: Test imports
try:
    from src.data.split import create_splits
    from src.data import get_dataloaders
    from src.models import get_model
    from src.training import Trainer
    from configs.default import get_config
    print("✓ All imports successful!")
except ImportError as e:
    print(f"✗ Import error: {e}")
    import traceback
    traceback.print_exc()



✓ All imports successful!


In [8]:
# Cell 4: Configuration
EXPERIMENT_NAME = "squeezenet_baseline_v1"

CONFIG = {
    # Data
    "data_dir": "/kaggle/input/fusarium-wilt-disease-in-chickpea-dataset/FUSARIUM-22/dataset_raw",
    "splits_dir": "/kaggle/working/splits",
    
    # Model
    "model_name": "squeezenet1_1",  # Options: squeezenet1_1, squeezenet1_1_cbam, mobilenetv2, efficientnet_b0
    "pretrained": True,
    "dropout": 0.5,
    
    # Training
    "batch_size": 32,
    "epochs": 50,
    "learning_rate": 1e-3,
    "weight_decay": 1e-4,
    "optimizer": "adamw",
    "scheduler": "cosine",
    
    # Class imbalance
    "use_class_weights": True,
    "use_weighted_sampler": False,
    
    # W&B
    "wandb_enabled": True,
    "wandb_entity": "tklwin_msds",
    "wandb_project": "chickpea",
    
    # Reproducibility
    "seed": 42,
}

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

Configuration:
  data_dir: /kaggle/input/fusarium-wilt-disease-in-chickpea-dataset/FUSARIUM-22/dataset_raw
  splits_dir: /kaggle/working/splits
  model_name: squeezenet1_1
  pretrained: True
  dropout: 0.5
  batch_size: 32
  epochs: 50
  learning_rate: 0.001
  weight_decay: 0.0001
  optimizer: adamw
  scheduler: cosine
  use_class_weights: True
  use_weighted_sampler: False
  wandb_enabled: True
  wandb_entity: tklwin_msds
  wandb_project: chickpea
  seed: 42


In [9]:
# Cell 5: Create data splits (run once)
splits = create_splits(
    data_dir=CONFIG["data_dir"],
    output_dir=CONFIG["splits_dir"],
    seed=CONFIG["seed"]
)

print("\n✓ Data splits created!")

Total images found: 4339

Original class distribution:
original_class
1(HR)     959
3(R)     1177
5(MR)    1133
7(S)      558
9(HS)     512
Name: count, dtype: int64

Merged class distribution:
label
0    2136
1    1133
2    1070
Name: count, dtype: int64

Split Statistics:

Train: 3037 images (70.0%)
  Class distribution:
    Class 0: 1495 (49.2%)
    Class 1: 793 (26.1%)
    Class 2: 749 (24.7%)

Val: 651 images (15.0%)
  Class distribution:
    Class 0: 320 (49.2%)
    Class 1: 170 (26.1%)
    Class 2: 161 (24.7%)

Test: 651 images (15.0%)
  Class distribution:
    Class 0: 321 (49.3%)
    Class 1: 170 (26.1%)
    Class 2: 160 (24.6%)

✓ Splits saved to: /kaggle/working/splits

✓ Data splits created!
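
The merged distribution printed above is consistent with grouping the five original classes into three labels: `1(HR)` and `3(R)` into label 0 (959 + 1177 = 2136), `5(MR)` into label 1 (1133), and `7(S)` and `9(HS)` into label 2 (558 + 512 = 1070). The sketch below only checks that arithmetic. The name `MERGE_MAP` is hypothetical and the mapping is inferred from the counts, so the actual logic inside `create_splits` may differ.

```python
from collections import Counter

# Hypothetical mapping inferred from the printed counts, not taken from the repo
MERGE_MAP = {"1(HR)": 0, "3(R)": 0, "5(MR)": 1, "7(S)": 2, "9(HS)": 2}

# Original per-class image counts as printed by create_splits
original_counts = {"1(HR)": 959, "3(R)": 1177, "5(MR)": 1133, "7(S)": 558, "9(HS)": 512}

# Sum the original counts into their merged labels
merged_counts = Counter()
for name, count in original_counts.items():
    merged_counts[MERGE_MAP[name]] += count

print(dict(merged_counts))
print("total:", sum(merged_counts.values()))
```

Label 0 holds roughly half of the 4339 images, which is why the later cells enable class weighting.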


In [10]:
# Cell 6: Setup W&B
import wandb
import os

# Load API key from Kaggle Secrets
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
    os.environ["WANDB_API_KEY"] = wandb_api_key
    print("✓ W&B API key loaded from Kaggle Secrets")
except Exception as e:
    print(f"⚠ Could not load W&B API key: {e}")
    print("  Disabling W&B logging...")
    CONFIG["wandb_enabled"] = False

✓ W&B API key loaded from Kaggle Secrets


In [11]:
# Cell 7: Build config
config = get_config(
    data={
        "data_dir": CONFIG["data_dir"],
        "splits_dir": CONFIG["splits_dir"],
        "batch_size": CONFIG["batch_size"],
    },
    model={
        "model_name": CONFIG["model_name"],
        "pretrained": CONFIG["pretrained"],
        "dropout": CONFIG["dropout"],
    },
    training={
        "epochs": CONFIG["epochs"],
        "learning_rate": CONFIG["learning_rate"],
        "weight_decay": CONFIG["weight_decay"],
        "optimizer": CONFIG["optimizer"],
        "scheduler": CONFIG["scheduler"],
        "use_class_weights": CONFIG["use_class_weights"],
        "use_weighted_sampler": CONFIG["use_weighted_sampler"],
        "seed": CONFIG["seed"],
        "checkpoint_dir": "/kaggle/working/checkpoints",
    },
    wandb={
        "enabled": CONFIG["wandb_enabled"],
        "entity": CONFIG["wandb_entity"],
        "project": CONFIG["wandb_project"],
        "run_name": EXPERIMENT_NAME,
        "tags": [CONFIG["model_name"], "baseline"],
    }
)

print("✓ Config built")

✓ Config built
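
With `use_class_weights` enabled, the trainer in the next cell reports `inverse_freq` weights of 0.6146, 1.1587 and 1.2267 for the training counts 1495, 793 and 749. Those values are consistent with inverse class frequencies rescaled so that the weights sum to the number of classes. The sketch below reproduces them under that assumption. The formula is inferred from the logged numbers, not read from the repository's code.

```python
# Train-split counts per merged label, from the split statistics above
counts = [1495, 793, 749]

# Assumed formula: inverse frequencies rescaled to sum to the number of classes
inverse = [1.0 / c for c in counts]
scale = len(counts) / sum(inverse)
weights = [w * scale for w in inverse]

for label, (w, c) in enumerate(zip(weights, counts)):
    print(f"Class {label}: {w:.4f} (count: {c})")
```

Under this scheme the majority class gets a weight below 1 and the two smaller classes get weights above 1, so each class contributes more evenly to the weighted cross-entropy loss.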


In [12]:
# Cell 8: Create trainer and start training
trainer = Trainer(config)
trainer.fit()

Using device: cuda
Downloading: "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth" to /root/.cache/torch/hub/checkpoints/squeezenet1_1-b8a52dc0.pth


100%|██████████| 4.73M/4.73M [00:00<00:00, 82.5MB/s]



Model: squeezenet1_1
  Total params: 724,035
  Trainable params: 724,035
  Size: 2.76 MB

DataLoaders created:
  Train: 3037 images, 94 batches
  Val: 651 images, 21 batches
  Test: 651 images, 21 batches

Class weights (inverse_freq):
  Class 0: 0.6146 (count: 1495)
  Class 1: 1.1587 (count: 793)
  Class 2: 1.2267 (count: 749)

Using weighted CrossEntropyLoss


wandb: Currently logged in as: tklwin (tklwin_msds) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin



Starting Training


Epoch 1/50 [Train]: 100%|██████████| 94/94 [00:42<00:00,  2.20it/s, loss=1.0948, acc=29.92%]
Epoch 1/50 [Val]: 100%|██████████| 21/21 [00:09<00:00,  2.32it/s]



Epoch 1/50
  Train Loss: 1.7338 | Train Acc: 29.92%
  Val Loss: 1.0986 | Val Acc: 49.16% | Val F1: 32.40%
  LR: 1.00e-03
  ✓ New best model saved! (Acc: 49.16%)


Epoch 2/50 [Train]: 100%|██████████| 94/94 [00:33<00:00,  2.80it/s, loss=1.0942, acc=43.58%]
Epoch 2/50 [Val]: 100%|██████████| 21/21 [00:06<00:00,  3.15it/s]



Epoch 2/50
  Train Loss: 1.0989 | Train Acc: 43.58%
  Val Loss: 1.0986 | Val Acc: 49.16% | Val F1: 32.40%
  LR: 1.00e-03


Epoch 3/50 [Train]: 100%|██████████| 94/94 [00:34<00:00,  2.75it/s, loss=1.0996, acc=37.80%]
Epoch 3/50 [Val]: 100%|██████████| 21/21 [00:06<00:00,  3.17it/s]



Epoch 3/50
  Train Loss: 1.0986 | Train Acc: 37.80%
  Val Loss: 1.0986 | Val Acc: 49.16% | Val F1: 32.40%
  LR: 1.00e-03


Epoch 4/50 [Train]:  48%|████▊     | 45/94 [00:16<00:18,  2.66it/s, loss=1.0986, acc=45.42%]


KeyboardInterrupt: 
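
The validation metrics are identical in all three completed epochs, and each can be reproduced by hand. A loss of 1.0986 equals ln 3, the cross-entropy of a uniform prediction over three classes. An accuracy of 49.16% equals the share of class 0 in the validation split (320 of 651). An F1 of 32.40% matches a support-weighted F1 when every image is predicted as class 0. This suggests, but does not prove, that the network collapsed to a constant output. The sketch below checks the arithmetic. Treating `Val F1` as support-weighted is an assumption inferred from the numbers.

```python
import math

# Validation images per class, from the split statistics
val_counts = [320, 170, 161]
total = sum(val_counts)

# Cross-entropy of a uniform prediction over 3 classes
uniform_loss = math.log(len(val_counts))

# Accuracy if every image is predicted as class 0 (the majority class)
majority_acc = val_counts[0] / total

# F1 for class 0 under that constant prediction; classes 1 and 2 score 0
precision_0 = val_counts[0] / total   # every prediction is class 0
recall_0 = 1.0                        # every true class-0 image is recovered
f1_0 = 2 * precision_0 * recall_0 / (precision_0 + recall_0)

# Support-weighted F1 (assumed averaging mode)
weighted_f1 = f1_0 * val_counts[0] / total

print(f"uniform loss : {uniform_loss:.4f}")
print(f"majority acc : {majority_acc:.2%}")
print(f"weighted F1  : {weighted_f1:.2%}")
```

One possible cause, if the model keeps torchvision's stock SqueezeNet classifier head (which ends in a ReLU before global average pooling), is that all class activations are clamped to zero, giving uniform softmax outputs and no gradient. A lower learning rate than 1e-3 is a common first thing to try, though the log alone does not confirm the cause.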