In [None]:
import torch
from torch import nn
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
import torchvision
import torchvision.transforms.v2 as v2
import os
import matplotlib.pyplot as plt
import numpy as np
import lightning as pl
from torchmetrics import Accuracy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer
import wandb
import numpy as np
import gc
import sys
sys.path.append(os.path.abspath("src"))
from src.dataloader import iNat_dataset
from src.config import *
from src.models import pretrained_light, Unfreeze_after_epochs



In [None]:
# Authenticate with Weights & Biases without hardcoding the key in the notebook.
# An already-exported WANDB_API_KEY is used as-is (the previous version overwrote it
# with a placeholder); otherwise the key is requested via getpass, which hides the
# input so the secret never ends up in the saved cell source or its output.
import getpass

if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass.getpass("Enter your W&B API key: ")
wandb.login(key=os.environ["WANDB_API_KEY"])

In [None]:
# Getting train, validation, and test dataloaders.
# Seed python/numpy/torch RNGs (and dataloader worker RNGs) before anything random
# happens, so any random split or shuffling inside iNat_dataset and the later model
# initialisation are reproducible across Restart & Run All.
SEED = 42  # NOTE(review): consider moving this into src.config with the other hyperparameters
pl.seed_everything(SEED, workers=True)

# data_dir, aug, batch_size and NUM_WORKERS are not defined in this notebook; they
# come in through the `from src.config import *` star import.
dataset = iNat_dataset(data_dir=data_dir, augmentation=aug, batch_size=batch_size, NUM_WORKERS=NUM_WORKERS)
train_dataloader, val_dataloader, test_dataloader, classes, n_classes = dataset.load_dataset()

In [None]:
# Finetuning the model.
# optim, lr, unfreeze_at_epoch, project_name, run_name and epochs come from src.config
# (star import); n_classes and the dataloaders come from the previous cell.
finetune_model = pretrained_light(optim=optim, n_classes=n_classes, lr=lr)
# Project callback; by its name it unfreezes layers after `unfreeze_at_epoch` epochs — TODO confirm in src.models.
callback = Unfreeze_after_epochs(unfreeze_at_epoch=unfreeze_at_epoch)
logger = WandbLogger(project=project_name, name=run_name, log_model=False)
trainer = pl.Trainer(
    devices=1,
    accelerator="auto",
    precision="16-mixed",      # automatic mixed precision
    gradient_clip_val=1.0,     # clip gradients (Lightning's default algorithm is norm clipping)
    max_epochs=epochs,
    logger=logger,
    profiler=None,
    callbacks=[callback],
)

try:
    trainer.fit(finetune_model, train_dataloader, val_dataloader)
    # Persist the fine-tuned weights immediately after training so they are not lost
    # if the test pass below raises.
    trainer.save_checkpoint("finetuned_model.ckpt")
    trainer.test(finetune_model, dataloaders=test_dataloader)
finally:
    # Always close the W&B run. If it is left open after an exception, a new
    # WandbLogger created by re-running this cell reuses the stale run.
    wandb.finish()