# Dragon Ball Classifier
An image classification model for Dragon Ball characters.

In [None]:
# Install dependencies

!pip uninstall "torchtext" "torchaudio"
!pip install "torch==1.4.0" "torchvision==0.5.0"
from fastai.vision import *
import os

# Data Preparation
Upload the `dragonball` folder from [`xyntechx/Dragon-Ball-Classifier`](https://github.com/xyntechx/Dragon-Ball-Classifier) to the **Files** side panel in Google Colab. It contains one `<character>.csv` file of image URLs per character (e.g. `vegeta.csv`).

Note: only the characters listed in `characters` below (Vegeta, Goku, Gohan, and Trunks) are downloaded and trained on. The `dragonball` folder can therefore safely contain `.csv` URL files for other characters; they are simply not used.

In [2]:
# Characters to download and classify. Each name is also used as the image
# subfolder (dragonball/<name>/) and the URL file (dragonball/<name>.csv).
characters = [
  "vegeta",
  "goku",
  "gohan",
  "trunks",
]

In [3]:
# Create one subfolder per character under dragonball/ to receive its images.
# `path` is defined once, outside the loop, so it exists even if `characters`
# is empty. exist_ok=True makes re-running this cell harmless.
path = Path("dragonball")

for character in characters:
  dest = path/character
  dest.mkdir(parents=True, exist_ok=True)

In [None]:
# Download images for training: read image URLs from dragonball/<character>.csv
# and save up to MAX_PICS of them into dragonball/<character>/.
MAX_PICS = 100
path = Path("dragonball")

for character in characters:
  # Derive the CSV location from `path` so it cannot drift from the image folder.
  csv_file = path/f"{character}.csv"
  dest = path/character
  download_images(csv_file, dest, max_pics=MAX_PICS)

In [None]:
# Remove unopenable images: delete=True deletes files that cannot be opened.
# Per the fastai v1 docs, max_size=500 also downscales larger images in place.
# `path` is set once here (not as a loop side effect) because the training
# cell below reads it.
path = Path("dragonball")

for character in characters:
  verify_images(path/character, delete=True, max_size=500)

# Model Training

In [None]:
# Train a ResNet-101 classifier on the dragonball/<character>/ image folders.
# `path` is set explicitly so this cell does not depend on a value leaked from
# an earlier loop. SPLIT_SEED makes the random train/validation split
# reproducible across reruns; weight initialisation and augmentation are not
# seeded here.
path = Path("dragonball")
SPLIT_SEED = 42

data = ImageDataBunch.from_folder(
  path,
  train=".",          # class subfolders sit directly under dragonball/
  valid_pct=0.2,      # hold out 20% of the images for validation
  seed=SPLIT_SEED,
  ds_tfms=get_transforms(),
  size=224,
  num_workers=4,
).normalize(imagenet_stats)
learner = cnn_learner(data, models.resnet101, metrics=accuracy)
learner.fit_one_cycle(20)

# Result Interpretation

In [None]:
# Build an interpretation of the trained learner's predictions (fastai v1
# defaults to the validation set) and draw its confusion matrix.
interp = ClassificationInterpretation.from_learner(learner)
interp.plot_confusion_matrix();

# Model Exportation
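`learner.export()` writes the trained model to a `.pkl` file (`export.pkl`) in the `dragonball` folder. Download it from the **Files** side panel if you want to keep the model after the Colab session ends.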

In [21]:
# Serialize the trained learner for inference with load_learner() below.
# fastai v1 writes the file into the learner's path, i.e. dragonball/export.pkl.
learner.export("export.pkl")

# Model Usage
Upload the `test` folder from [`xyntechx/Dragon-Ball-Classifier`](https://github.com/xyntechx/Dragon-Ball-Classifier) to the **Files** side panel in Google Colab. The cell below classifies the images `test/1.png` through `test/4.png`.

Note: unless you have modified the dataset and the code above to include more characters, the model can only classify Vegeta, Goku, Gohan, and Trunks.

In [None]:
# Reload the exported learner from dragonball/ and classify test/1.png ... test/4.png.
# Class names come from the loaded learner (`model.data.classes`) rather than the
# training-time `data` object, so this cell also works in a fresh session that
# only has the exported model. `path` is set here for the same reason.
path = Path("dragonball")
model = load_learner(path)
class_names = model.data.classes
print(class_names)

N_TEST_IMAGES = 4

for i in range(1, N_TEST_IMAGES + 1):
  img_path = f"test/{i}.png"
  img = open_image(img_path)
  pred_class, pred_idx, outputs = model.predict(img)

  # `outputs` has one score per class, in the same order as class_names
  # (presumably softmax probabilities in fastai v1 — confirm).
  scores = {name: round(score, 3) for name, score in zip(class_names, outputs.tolist())}
  print(img_path, "->", pred_class, scores)