### Training a CNN from scratch

This notebook demonstrates how to train a custom CNN from scratch on the iNaturalist dataset. Set the configuration (data paths and hyperparameters) in `src/config.py` before running it.

In [None]:
# Importing libraries
import torch
import os
import lightning as pl
from lightning.pytorch.loggers import WandbLogger
import wandb
import sys
sys.path.append(os.path.abspath("src"))
from src.dataloader import iNat_dataset
from src.models import CNN_light
from src.config import *
from src.utils import plot_sample_images, visualize_filters

In [None]:
# Log in to Weights & Biases without storing the secret in the notebook.
# The key is taken from the WANDB_API_KEY environment variable when it is already
# set (e.g. exported before launching Jupyter); otherwise you are prompted for it
# and the typed input is not echoed. The previous version overwrote the variable
# with a literal placeholder, which clobbered any real key already configured.
import getpass

wandb_api_key = os.environ.get("WANDB_API_KEY") or getpass.getpass("W&B API key: ")
os.environ["WANDB_API_KEY"] = wandb_api_key  # keep it visible to WandbLogger / subprocesses
wandb.login(key=wandb_api_key)

In [None]:
# Build the iNaturalist data wrapper from the settings in src/config.py, then
# unpack the train / validation / test loaders along with the class names and
# the number of classes (consumed by the model constructor below).
dataset = iNat_dataset(
    data_dir=data_dir,
    augmentation=aug,
    batch_size=batch_size,
    NUM_WORKERS=NUM_WORKERS,
)
loaders_and_meta = dataset.load_dataset()
train_dataloader, val_dataloader, test_dataloader, classes, n_classes = loaders_and_meta

In [None]:
# Model training and evaluation.
# Build the CNN from the hyperparameters in src/config.py, train it (validating
# on val_dataloader), evaluate on the held-out test set, and save the final
# weights to best_model.ckpt. Metrics are logged to the W&B project 'CNN_scratch'.
# NOTE(review): `n_classses` (three "s") is passed exactly as before; it must match
# the parameter name in CNN_light's constructor — confirm in src/models.py before renaming.
model = CNN_light(optim=optim, filters=filters, kernel=kernel, pool_kernel=pool_kernel,
                  pool_stride=pool_stride, batchnorm=batchnorm, activation=activation,
                  dropout=dropout, ffn_size=ffn_size, n_classses=n_classes, lr=lr)
logger = WandbLogger(project='CNN_scratch', name='best_model', log_model=False)

# Fall back to CPU when no CUDA device is present. fp16 mixed precision is a
# GPU feature, so full 32-bit precision is used on CPU; GPU behavior is unchanged.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    devices=1,
    accelerator="gpu" if use_gpu else "cpu",
    precision="16-mixed" if use_gpu else "32-true",
    gradient_clip_val=1.0,
    max_epochs=epochs,
    logger=logger,
    profiler=None,
)

# Always close the W&B run, even if fit/test/save raises. Otherwise the run stays
# open in this kernel and a re-run of the cell would attach to the stale run.
try:
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.test(model, dataloaders=test_dataloader)
    trainer.save_checkpoint("best_model.ckpt")
finally:
    wandb.finish()



In [None]:
# Qualitative checks on the trained model. Both helpers take the same inputs:
# the test loader and the trained model. The first plots sample images; the
# second visualizes the learned filters.
viz_inputs = (test_dataloader, model)
plot_sample_images(*viz_inputs)
visualize_filters(*viz_inputs)
