In [None]:
# Fetch the project sources and make the flower-classification folder the
# working directory. Later cells import `core` / `utils` and use relative
# `./data_2/...` paths, which presumably live under this folder — verify if the
# repository layout changes.
!git clone https://github.com/trinhdvt/PyTorch-DL-fundamental.git
%cd ./PyTorch-DL-fundamental/flower-classification

### Install dependencies

In [None]:
%%capture
!pip3 install -q gdown
!pip3 install -q torchsummary
!pip3 install -q wandb

In [None]:
import wandb

# Never embed the W&B API key in the notebook. `wandb.login()` picks the key up
# from the WANDB_API_KEY environment variable or a previous login, and prompts
# interactively otherwise.
# NOTE(review): the key that used to be hardcoded here has been committed to
# version control and should be revoked and regenerated.
wandb.login()

# Start the run that the later `wandb.log(...)` calls report to.
wandb.init(project="flower-classification", entity="trinhdvt")

### Download data

In [None]:
!gdown --id 18CQK_JXSgVny-fSYFPv6qqy_lYeBfm0U
!unzip -q data_2.zip
!rm -f data_2.zip

### Import libraries

In [None]:
import torch
import matplotlib.pyplot as plt

from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

from core import build_model, data_loader, train_helper
from utils import display_utils

from warnings import filterwarnings

filterwarnings("ignore", category=DeprecationWarning)
filterwarnings("ignore", category=FutureWarning)
filterwarnings("ignore", category=UserWarning)
torch.manual_seed(0)

%matplotlib inline

### Load data with PyTorch

In [None]:
# Dataset split locations, relative to the current working directory.
train_dir = "./data_2/train/"
val_dir = "./data_2/valid/"
test_dir = "./data_2/test/"

# Loader settings shared with later cells through `params`.
# NOTE(review): "num_workers" is recorded here but is not forwarded to
# `load_data` below — confirm whether `load_data` accepts it.
params = {
    "batch_size": 32,
    "num_workers": 0,
    # is_available() already returns a bool; pinning is only enabled with CUDA.
    "pin_memory": torch.cuda.is_available(),
}

# Note the argument order (train, test, val) and the matching order of the
# returned loaders (train, test, val).
class_names, train_loader, test_loader, val_loader = data_loader.load_data(
    train_dir,
    test_dir,
    val_dir,
    batch_size=params['batch_size'],
    pin_memory=params['pin_memory'])

print("Class: ", class_names)
for name, loader in [("Train set", train_loader),
                     ("Validation set", val_loader),
                     ("Testing set", test_loader)]:
    # Count samples from the dataset itself: `len(loader) * batch_size`
    # over-reports whenever the final batch is only partially filled.
    # Assumes the loaders are torch DataLoader objects exposing `.dataset`.
    print(f"{name}: {len(loader.dataset)} images")

### Visualize First Train Batch

In [None]:
# Render only the first training batch: the figure is written to disk and the
# saved file is then attached to the wandb run. `images` and `labels` stay
# bound afterwards and `images` is reused by the TensorBoard setup later on.
first_batch_path = "train_batch.jpg"
for images, labels in train_loader:
    display_utils.display(images, labels, class_names,
                          save_path=first_batch_path)
    wandb.log({"First train batch": wandb.Image(first_batch_path)})
    break  # a single batch is enough for a visual sanity check

### Define CNN Model

In [None]:
# summary(build_model.CNN(num_classes=5), (3, 224, 224), batch_size=1, device="cpu")

### Train model

In [None]:
# Build the network with one output per class found in the dataset.
model = build_model.CNN(num_classes=len(class_names),
                        input_size=224,
                        in_channels=3)

# Training hyperparameters and device placement.
params.update({
    "epochs": 150,
    "lr": 0.001,
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    # is_available() already returns a bool; non-blocking copies only with CUDA.
    "non_blocking": torch.cuda.is_available(),
})
# Loss and optimizer objects are added in a second step because the optimizer
# needs params['lr'] to exist.
# NOTE: the key is spelled "optimzer" (sic); the training cell looks it up
# under exactly this spelling, so it must not be renamed here alone.
params.update({
    "loss_fn": torch.nn.NLLLoss(),
    "optimzer": torch.optim.Adam(model.parameters(), lr=params['lr'])
})

# init tensorboard: log an image grid of the first training batch (`images`
# comes from the visualisation cell above) and the model graph.
tb = SummaryWriter()
grid = make_grid(data_loader.inv_normalize(images))
tb.add_image("images", grid)
tb.add_graph(model, images)
tb.close()

# Record the hyperparameters on the active run. Assigning `wandb.config = ...`
# only rebinds the module attribute and leaves the run's config untouched, so
# update the existing config object instead. Only plain scalar settings are
# sent; the loss/optimizer objects are not serialisable configuration values.
# `allow_val_change=True` keeps the cell re-runnable after editing a value.
wandb.config.update(
    {
        key: value
        for key, value in params.items()
        if isinstance(value, (bool, int, float, str))
    },
    allow_val_change=True,
)

In [None]:
# Place the model on the device selected in `params` before training starts.
model.to(params['device'])

# wandb.watch(model)

# Run the training loop; the returned history is kept in `train_hist` and
# plotted in the "Visualize Training results" section.
train_hist = train_helper.traning_loops(
    model=model,
    epochs=params['epochs'],
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=params['loss_fn'],
    optimizer=params['optimzer'],
    device=params['device'],
    non_blocking=params['non_blocking'],
)

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir runs

### Visualize Training results

In [None]:
# One panel per metric family: validation vs. training curves side by side.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
metrics_name = (("val_loss", "train_loss"), ("val_acc", "train_acc"))
titles = ("Loss", "Accuracy")

for ax, title, (val_key, train_key) in zip(axes.flat, titles, metrics_name):
    ax.plot(train_hist[val_key], label=val_key)
    ax.plot(train_hist[train_key], label=train_key)
    # Title and y-label make each panel readable on its own, e.g. in wandb.
    ax.set_title(title)
    ax.set_xlabel("epoch")
    ax.set_ylabel(title.lower())
    ax.legend()

fig.tight_layout()
# Save through the figure handle rather than pyplot's implicit "current figure".
fig.savefig("train_hist.jpg")
wandb.log({"train_hist": wandb.Image("train_hist.jpg")})
plt.show()

### Measuring Accuracy

In [None]:
# Evaluate the trained model on the test loader, using the same device and
# transfer settings as training.
train_helper.test_model(
    model,
    device=params['device'],
    non_blocking=params['non_blocking'],
    test_loader=test_loader,
)

### Save model

In [None]:
# Persist only the model's state dict to "last_model.pth" in the working directory.
torch.save(obj=model.state_dict(), f="last_model.pth")

### Show test result

In [None]:
# Put the model in evaluation mode and disable gradient tracking while
# predicting on a single test batch.
model.eval()
with torch.no_grad():
    for imgs, labels in test_loader:
        # move the batch to the device the model lives on
        imgs, labels = imgs.to(params['device']), labels.to(params['device'])

        # forward pass, then take the index of the largest output along dim 1
        # as the predicted class
        outputs = model(imgs)
        predicted = torch.max(outputs, dim=1).indices

        # draw the batch with ground truth and predictions, save it, and
        # attach the saved image to the wandb run
        display_utils.display(
            images=imgs,
            truth_labels=labels,
            class_names=class_names,
            predicted=predicted,
            figsize=(10, 10),
            save_path="test_result.jpg",
        )
        wandb.log({"test_result": wandb.Image("test_result.jpg")})
        break  # only the first test batch is visualised