<a href="https://colab.research.google.com/github/wshuyi/info-5731-public/blob/master/demo_05_big_cats_image_classification_deep_transfer_learning_202104_unt_resnet18_CAM_simplified.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import fastai
fastai.__version__

In [None]:
!pip install -U fastai

In [None]:
!pip install torchcam

In [None]:
from torchvision.io.image import read_image
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchcam.cams import SmoothGradCAMpp
from torchcam.cams import SmoothGradCAMpp
import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask

from sklearn.metrics import confusion_matrix, classification_report
import shutil
import os
from pathlib import Path
from fastai.vision.all import *

In [None]:
def get_labels_and_preds(predictions):
    """Split a ``(probabilities, targets)`` prediction pair into labels and predicted ids.

    Args:
        predictions: A pair such as the one returned by ``Learner.get_preds()``.
            ``predictions[0]`` holds one row per sample and one column per class;
            ``predictions[1]`` holds the true class ids.

    Returns:
        ``(labels, preds)`` where ``labels`` is ``predictions[1]`` unchanged and
        ``preds`` is an integer pandas Series with, for every row, the index of the
        highest-scoring column (ties resolve to the lowest column index).
    """
    probs = pd.DataFrame(predictions[0])
    # Row-wise argmax works for any number of classes. The former
    # ``probs[0] < probs[1]`` test silently ignored every class after the second.
    preds = probs.idxmax(axis=1).astype(int)
    labels = predictions[1]
    return labels, preds

def visualize_cam_on_img(img_name, model):
    """Show a Smooth Grad-CAM++ heat-map for the model's top class over an image.

    The image is read from disk, resized to 224x224, scaled to [0, 1] and
    normalized with the ImageNet mean/std before being fed to ``model``. The
    highest-scoring class is explained with torchcam's ``SmoothGradCAMpp``, and the
    overlay is displayed with matplotlib.

    The model is evaluated in eval mode and its previous train/eval mode is
    restored afterwards. The CAM hooks are removed before returning, so repeated
    calls do not stack extractors on the same model.

    Args:
        img_name: Path (``str`` or ``pathlib.Path``) of the image file.
        model: The PyTorch classifier to explain (e.g. ``learn.model``).
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    was_training = model.training
    cam_extractor = SmoothGradCAMpp(model)
    try:
        # Eval mode: no dropout and no BatchNorm running-stat updates from this image.
        model.eval()
        img = read_image(str(img_name))

        # Preprocess for the model: resize, scale to [0, 1], ImageNet normalization.
        input_tensor = normalize(
            resize(img, (224, 224)) / 255.,
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ).to(device)

        out = model(input_tensor.unsqueeze(0))
        # Retrieve the CAM for the highest-scoring class.
        activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)
    finally:
        # The hook-removal method is named differently across torchcam releases
        # (NOTE(review): ``remove_hooks`` in newer, ``clear_hooks`` in older ones; confirm).
        remove_hooks = (getattr(cam_extractor, "remove_hooks", None)
                        or getattr(cam_extractor, "clear_hooks", None))
        if remove_hooks is not None:
            remove_hooks()
        model.train(was_training)

    # NOTE(review): newer torchcam releases return a list with one (N, H, W) map per
    # target layer, while older ones return a single (H, W) map. Accept both.
    if isinstance(activation_map, (list, tuple)):
        activation_map = activation_map[0]
    if activation_map.dim() == 3:
        activation_map = activation_map.squeeze(0)

    # Resize the CAM to the image and blend it on top.
    result = overlay_mask(to_pil_image(img), to_pil_image(activation_map, mode='F'), alpha=0.5)
    plt.imshow(result)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

def rename_in_order(folder):
    """Renumber the files directly inside *folder* as ``001.jpg``, ``002.jpg``, ...

    Files are numbered in sorted name order, so the result is reproducible across
    machines. Sub-directories are left untouched. Every file is first moved into a
    uniquely named staging directory and then moved back, so a new name can never
    collide with a not-yet-renamed original. The ``.jpg`` suffix is always used,
    whatever the original suffix was.

    Args:
        folder: ``pathlib.Path`` of the directory whose files are renumbered.
    """
    import tempfile

    # Sort: directory listing order is filesystem-dependent, which made the
    # numbering non-deterministic. Skip directories so they are never renamed.
    files = sorted(path for path in folder.iterdir() if path.is_file())

    # A unique staging name cannot clash with an existing entry (the old fixed
    # name "temp" made mkdir() fail whenever such an entry already existed).
    staging = Path(tempfile.mkdtemp(prefix=".rename_", dir=folder))

    staged = []
    for index, source in enumerate(files, start=1):
        target = staging / f"{index:03}.jpg"
        shutil.move(str(source), str(target))
        staged.append(target)

    for file in staged:
        shutil.move(str(file), str(folder / file.name))

    # Empty by now; rmdir (not rmtree) so nothing can be deleted by accident.
    staging.rmdir()

def init_images_dir(images):
    """(Re)create the dataset root with empty ``train``, ``valid`` and ``test`` folders.

    Any existing directory at *images* is removed first, together with its contents.

    Args:
        images: ``pathlib.Path`` of the dataset root to (re)create.
    """
    if images.exists():
        shutil.rmtree(images)  # start from a clean slate
    images.mkdir()
    for split in ("train", "valid", "test"):
        (images / split).mkdir()

def split_images_into_train_valid_test(mycls, images, images_original, train_ratio=0.7, test_ratio=0.1):
    """Copy one class's ``.jpg`` images into the ``train``/``valid``/``test`` splits.

    Files are taken from ``images_original / mycls`` in sorted name order:
    - the first ``round(n * train_ratio)`` go to train;
    - the last ``round(n * test_ratio)`` go to test (capped so they never overlap
      train);
    - the remainder go to valid.

    Each split's new class folder is then renumbered with ``rename_in_order``.

    Args:
        mycls: Class (sub-folder) name, e.g. ``"cheetah"``.
        images: Destination root that already contains ``train``, ``valid`` and
            ``test`` sub-directories (as created by ``init_images_dir``).
        images_original: Root holding one sub-folder of ``.jpg`` files per class.
        train_ratio: Fraction of the files assigned to the training split.
        test_ratio: Fraction of the files assigned to the test split.

    Raises:
        ValueError: If either ratio lies outside ``[0, 1]``.
        FileExistsError: If ``mycls`` already exists inside one of the splits.
    """
    if not (0 <= train_ratio <= 1 and 0 <= test_ratio <= 1):
        raise ValueError("train_ratio and test_ratio must each be between 0 and 1")

    # Sort so the split does not depend on the filesystem's directory order.
    files = sorted((images_original / mycls).glob("*.jpg"))
    n_files = len(files)

    train_length = round(n_files * train_ratio)
    # Clamp: when the rounded ratios add up to more than n_files, the train and
    # test slices would otherwise share files (train/test leakage).
    test_length = min(round(n_files * test_ratio), n_files - train_length)

    splits = {
        "train": files[:train_length],
        "valid": files[train_length:n_files - test_length],
        "test": files[n_files - test_length:],
    }

    for split_name, split_files in splits.items():
        class_folder = images / split_name / mycls
        class_folder.mkdir()
        for file in split_files:
            shutil.copy(file, class_folder)
        rename_in_order(class_folder)

In [None]:
# Download the original images (one sub-folder of .jpg files per class).
# git refuses to clone into an existing non-empty folder, so re-running this cell fails.
!git clone https://github.com/wshuyi/big-cats-image-original.git

In [None]:
# make targets: build the train/valid/test folders from the original images.

In [None]:
# Destination root for the split dataset (relative to the current directory).
images = Path('images')

In [None]:
# Delete images/ if present and recreate it with empty train, valid and test sub-folders.
init_images_dir(images)

In [None]:
# Folder created by the git clone above.
images_original = Path("big-cats-image-original/")

In [None]:
# Copy each class into the three splits (defaults: 70% train, 10% test, the rest valid).
for mycls in ["cheetah", "jaguar"]:
  split_images_into_train_valid_test(mycls, images, images_original)

In [None]:
# Sanity check: total number of .jpg files across all splits and classes.
len(list(images.glob("*/*/*.jpg")))

In [None]:
# Training configuration.
arch = resnet18
# run_date = "20210418"
metrics=[accuracy, error_rate]
# Every item is resized to 224 before batching; aug_transforms() supplies fastai's default batch augmentations.
item_tfms=Resize(224)
batch_tfms=aug_transforms()
seed=2
# Maximum number of fine-tuning epochs; EarlyStoppingCallback below may stop earlier.
epochs = 20


In [None]:
working_dir = Path(".")

In [None]:
# Train on images/train and validate on images/valid; the class label is each image's sub-folder name.
dls = ImageDataLoaders.from_folder(images, train='train', valid='valid', seed=seed, item_tfms=item_tfms, batch_tfms=batch_tfms)

In [None]:
dls.show_batch()

In [None]:
# Transfer-learning model built from `arch` for the classes found in dls.
learn = cnn_learner(dls, arch, metrics=metrics)

In [None]:
# learn.lr_find()

In [None]:
# Base learning rate for fine_tune (presumably chosen from an earlier lr_find run; confirm).
base_lr = 3e-3

In [None]:
%%time
# NOTE(review): with default arguments both callbacks monitor valid_loss (best checkpoint kept / early stop); confirm for the installed fastai.
learn.fine_tune(epochs=epochs, base_lr=base_lr, cbs=[SaveModelCallback(), EarlyStoppingCallback()])

In [None]:
# NOTE(review): Learner.save/load resolve a relative name against learn.path/learn.model_dir,
# so the weights likely land in images/models/fine_tuned.pth rather than working_dir. learn1.load
# below resolves the name the same way; confirm.
model_dump = working_dir/ f"fine_tuned"

In [None]:
learn.save(model_dump)

In [None]:
# Evaluation loaders: the held-out test split is passed as `valid` so that get_preds()/Interpretation score it.
# No batch_tfms here, i.e. no augmentation.
dls1 = ImageDataLoaders.from_folder(images, train='train', valid='test', seed=seed, item_tfms=item_tfms)

In [None]:
# Same architecture as `learn`, so the saved weights fit this model.
learn1 = cnn_learner(dls1, arch, metrics=metrics)

In [None]:
learn1.load(model_dump)

In [None]:
%%time
# NOTE(review): without arguments get_preds scores the validation loader (the test split here) and returns (probabilities, targets).
predictions = learn1.get_preds()

In [None]:
predictions

In [None]:
labels, preds = get_labels_and_preds(predictions)

In [None]:
# Class ids follow dls1.vocab order (presumably cheetah=0, jaguar=1; confirm with dls1.vocab).
print(classification_report(labels, preds))

In [None]:
print(confusion_matrix(labels, preds))

In [None]:
learn1.show_results()

In [None]:
interp = Interpretation.from_learner(learn1)

In [None]:
# Show the 9 evaluated images with the highest loss.
interp.plot_top_losses(9, figsize=(15,10))


In [None]:
#Visualize the arch

In [None]:
dummy_input = torch.randn(2, 3, 224, 224).cuda()

In [None]:
torch.onnx.export(learn1.model, dummy_input, "output.onnx", verbose=False)

In [None]:
model = learn1.model

In [None]:
img_name = images/"test/cheetah/001.jpg"
visualize_cam_on_img(img_name, model)

In [None]:
img_name = images/"test/cheetah/015.jpg"
visualize_cam_on_img(img_name, model)

In [None]:
img_name = images/"test/jaguar/001.jpg"
visualize_cam_on_img(img_name, model)

In [None]:
img_name = images/"test/jaguar/015.jpg"
visualize_cam_on_img(img_name, model)