In [4]:
import random
import shutil
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import time

from pathlib import Path
from PIL import Image

ImportError: DLL load failed while importing _C: The specified module could not be found.

In [None]:
# Set path variables
# Dataset root: holds one folder per class until the split, then train/ and test/.
IMG_DIR = Path('./spectograms')
# Presumably where images removed from the dataset are kept; not referenced
# elsewhere in this notebook — TODO confirm.
DEL_DIR = IMG_DIR / 'deleted'
# Every figure saved by save_fig_with_date ends up here.
PLOT_DIR = Path('./plots')
PLOT_DIR.mkdir(exist_ok=True)
# Image count consulted by split_train_and_valid when sampling validation indices.
NUM_IMG = 0

In [None]:
# Utility Functions

def split_train_and_valid(img_dir, image_types, random_seed=0):
  """Move each class folder under `img_dir` into `train/<class>` and `test/<class>`.

  For every class sub-directory of `img_dir` (skipping the `train`, `test` and
  `deleted` folders this notebook manages), one fifth of its PNG files, chosen
  with a RNG seeded by `random_seed`, are moved to `img_dir/test/<class>`. The
  remaining PNGs are moved to `img_dir/train/<class>`, and the emptied class
  folder is removed. Re-running after a split is a no-op.

  Args:
    img_dir: Path of the dataset root containing one sub-directory per class.
    image_types: Accepted for backward compatibility. Class names are taken
      from the directory names under `img_dir`, not from this list.
    random_seed: Seed for choosing the validation images.

  Raises:
    OSError: if a class folder is not empty after its PNGs were moved (for
      example, it contains nested sub-directories or non-PNG files).
  """
  # Folders created/used by this notebook that must never be treated as classes
  # ('deleted' matches DEL_DIR in the config cell).
  reserved = {'train', 'test', 'deleted'}
  # A local RNG makes the split reproducible without touching global `random` state.
  rng = random.Random(random_seed)
  class_dirs = sorted(
      child for child in img_dir.iterdir()
      if child.is_dir() and child.name not in reserved
  )
  for typ_dir in class_dirs:
    typ = typ_dir.name
    train_dir = img_dir / 'train' / typ
    test_dir = img_dir / 'test' / typ
    train_dir.mkdir(parents=True, exist_ok=True)
    test_dir.mkdir(parents=True, exist_ok=True)
    # Sort so the sampled subset does not depend on filesystem listing order.
    img_files = sorted(typ_dir.rglob('*.png'))
    # Size the sample per class from the real file count (20% to test).
    valid_indices = set(rng.sample(range(len(img_files)), len(img_files) // 5))
    for i, fn in enumerate(img_files):
      dest_dir = test_dir if i in valid_indices else train_dir
      shutil.move(str(fn), str(dest_dir / fn.name))
    os.rmdir(typ_dir)

def save_fig_with_date(figname):
  """Save the current matplotlib figure to PLOT_DIR as `<figname>_<YYYYmmdd_HHMMSS>.png`.

  Both the date and the time are included, so figures saved at the same clock
  time on different days do not overwrite each other.

  Returns:
    Path of the written PNG file.
  """
  # strftime gives a fixed-width, colon-free stamp (safe in Windows file names).
  out_path = PLOT_DIR / f"{figname}_{time.strftime('%Y%m%d_%H%M%S')}.png"
  plt.savefig(out_path)
  return out_path

def file_name_with_date(filename):
  """Return the base name of `filename` with a date/time stamp inserted before its suffix.

  Example: 'ckpt/model.pt' -> 'model_Jan_06_123456.pt' (month name, zero-padded
  day, HHMMSS). Any directory part of `filename` is dropped.
  """
  filename = Path(filename)
  # time.ctime() space-pads single-digit days ('Jan  6'). Splitting it on ' '
  # produced 'Jan__6' with the time missing on those days, and 'HH:MM:SS' (':' is
  # invalid in Windows file names) on the others. strftime has neither problem.
  stamp = time.strftime('%b_%d_%H%M%S')
  return f"{filename.stem}_{stamp}{filename.suffix}"

def plot_random_sampled_images(img_dir, ncols=4, nrows=5,random_seed=0):
  """Plot a seeded random sample of the PNG images under `img_dir` and save it.

  Images are found recursively, sorted (so the sample does not depend on
  filesystem order), shuffled with `random_seed`, and drawn in an
  `nrows x ncols` grid. If fewer images exist than grid cells, all of them are
  shown. Each title is the image's last two parent folder names (e.g.
  'train/<class>'). The figure is saved via `save_fig_with_date("dataset_check")`.
  """
  list_of_image_files = sorted(img_dir.rglob("*.png"))
  random.Random(random_seed).shuffle(list_of_image_files)
  n_show = min(ncols * nrows, len(list_of_image_files))

  # prepare empty canvas
  plt.figure(figsize=(ncols*5, nrows*4))

  # for each column and row, load image and plot it on the canvas
  for i, fname in enumerate(list_of_image_files[:n_show]):
    plt.subplot(nrows, ncols, i+1)
    # `with` releases the file handle; imshow converts the PIL image to an array.
    with Image.open(fname) as image:
      plt.imshow(image)
    # Path.parts is separator-agnostic (splitting on '/' failed on Windows).
    plt.title('/'.join(fname.parent.parts[-2:]))
    plt.axis('off')
  save_fig_with_date("dataset_check") # save figure as png in PLOT_DIR

### Download or Unzip
- If this is your first time downloading the dataset, the images will be downloaded automatically from a DuckDuckGo search.
- If you have already run the crawling code and have an `image_data.zip` file locally or in Google Drive, upload it to the current Colab storage and unzip it.

In [None]:
# Obtain the list of class names and sanity-check it.
# NOTE(review): `configure_image_categories` is not defined or imported anywhere
# in this notebook, so this cell raises NameError on a fresh kernel — TODO
# define or import it.
# NOTE(review): the recorded output lists guitar models while IMG_DIR points at
# spectrogram images; split_train_and_valid takes the class names from the
# folder names, not from this list — confirm which is intended.
image_types = configure_image_categories()
print(image_types)
# Expect exactly four class names, all strings.
assert len(image_types)==4, "The length of image types has to be 4"
assert all( isinstance(typ, str) for typ in image_types), "Every element of image_types has to be string"

['telecaster', 'stratocaster', 'fender jaguar', 'jazzmaster']


In [None]:
# Split the class folders into IMG_DIR/train and IMG_DIR/test (moves files on disk).
# Guarded so that re-running this cell does not try to split an already-split dataset.
if not (IMG_DIR / 'train').exists():
  split_train_and_valid(IMG_DIR, image_types)
else:
  print(f"{IMG_DIR / 'train'} already exists; skipping the train/test split.")

In [None]:
# Visual sanity check: 5 rows x 4 columns of randomly sampled PNGs (seed 1).
plot_random_sampled_images(IMG_DIR, nrows=5, ncols=4, random_seed=1)

### Dataset Class


In [None]:
class ImageSet:
  """Map-style image dataset labelled by each file's parent folder name.

  Implements `__len__` and `__getitem__`, so it can be wrapped in a
  `torch.utils.data.DataLoader`. Each item is `(transformed_image, class_index)`.

  Args:
    path_dir: Root directory; images are collected recursively.
    file_types: File extensions (without the dot) to include.
    transform: Callable applied to each RGB PIL image. If None, a default pipeline
      is used: Resize(256), CenterCrop(224), ToTensor, then Normalize with the
      ImageNet mean/std. (Previously this argument was silently ignored.)
  """
  def __init__(self, path_dir, file_types=('jpg', 'png'), transform=None):
    self.path = Path(path_dir)
    # Sorted so indices are stable across runs and machines.
    self.image_fns = sorted(
        fn for ext in file_types for fn in self.path.rglob(f'*.{ext}')
    )
    self.classes = sorted({fn.parent.name for fn in self.image_fns})
    self.cls2idx = {name: idx for idx, name in enumerate(self.classes)}
    if transform is None:
      transform = transforms.Compose([
          transforms.Resize(256),
          transforms.CenterCrop(224),
          transforms.ToTensor(),
          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      ])
    self.transform = transform

  def __len__(self):
    return len(self.image_fns)

  def __getitem__(self, idx):
    img_path = self.image_fns[idx]
    # convert('RGB') returns a loaded copy, so the file can be closed right away;
    # it also guarantees 3 channels for RGBA / greyscale inputs.
    with Image.open(img_path) as raw:
      img = raw.convert('RGB')
    return self.transform(img), self.cls2idx[img_path.parent.name]

# Datasets built from the folders produced by split_train_and_valid.
trainset, testset = (ImageSet(IMG_DIR / split) for split in ('train', 'test'))

Check a sample batch from the training loader

In [None]:
# Inverse of ImageSet's Normalize, used only for display:
#   x_norm = (x - mean) / std   =>   x = x_norm * std + mean
# The reciprocals are computed from the same mean/std constants rather than
# hard-coded as rounded literals. The result is clamped to [0, 1] because
# ToPILImage scales float tensors by 255 before casting to uint8, and values
# slightly outside that range could overflow and show up as speckle.
tensor2pil = transforms.Compose([
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1 / s for s in (0.229, 0.224, 0.225)]),
    transforms.Normalize(mean=[-m for m in (0.485, 0.456, 0.406)], std=[1.0, 1.0, 1.0]),
    transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
    transforms.ToPILImage(),
])

def show_batch(dataloader, ncols=4, nrows=5, random_seed=0):
  """Plot up to `ncols * nrows` images from the first batch of `dataloader`.

  `torch.manual_seed(random_seed)` is called before drawing the batch,
  presumably so that a shuffling loader shows the same batch on every call.
  Titles use the class names of the loader's own dataset (assumed to have a
  `.classes` list, as ImageSet does); previously the global `trainset` was
  always used. If the batch is smaller than the grid, only the available
  images are drawn.
  """
  torch.manual_seed(random_seed)
  images, labels = next(iter(dataloader))
  class_names = dataloader.dataset.classes
  n_show = min(ncols * nrows, len(images))
  plt.figure(figsize=(ncols*5, nrows*4))
  for i in range(n_show):
    plt.subplot(nrows, ncols, i+1)
    plt.imshow(tensor2pil(images[i]))
    plt.title(class_names[int(labels[i])])
    plt.axis('off')

# Training batches are reshuffled every epoch. The test loader keeps dataset
# order, which Interpreter relies on to match predictions to dataset indices.
train_loader, test_loader = (
    DataLoader(trainset, batch_size=32, shuffle=True),
    DataLoader(testset, batch_size=32, shuffle=False),
)

In [None]:
# Display one seeded training batch in the default 5 x 4 grid, then save it.
show_batch(train_loader, ncols=4, nrows=5, random_seed=0)
save_fig_with_date("train_batch")

### Trainer Class

In [None]:
class Trainer:
  """Train a classifier with Adam (lr=1e-3) and record training/validation metrics.

  The model is moved to `cuda:0` when CUDA is available, otherwise to the CPU.
  After every epoch it is evaluated on `valid_loader`. Whenever validation
  accuracy improves, its `state_dict` is saved to `<model_name>_best.pt`.

  Args:
    model: classifier returning raw class scores; the last dimension is treated
      as the class axis (presumably (batch, num_classes) — confirm for new models).
    train_loader: iterable of `(images, labels)` training batches.
    valid_loader: iterable of `(images, labels)` batches whose `.dataset`
      supports `len()`.
    model_name: prefix for the best-checkpoint file name.
  """
  def __init__(self, model, train_loader, valid_loader, model_name='resnet'):
    self.model = model
    self.train_loader = train_loader
    self.valid_loader = valid_loader
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.model.to(self.device)
    # NLLLoss applied to log-softmax outputs (see `_log_probs`) is cross-entropy.
    self.criterion = nn.NLLLoss()
    self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
    self.best_loss = np.inf  # kept for compatibility; not updated by this class
    self.best_acc = 0.0
    self.train_losses = []   # one entry per training step
    self.valid_losses = []   # one entry per epoch
    self.train_accs = []     # one entry per training step
    self.valid_accs = []     # one entry per epoch
    self.model_name = model_name

  @staticmethod
  def _log_probs(outputs):
    """Return log-probabilities over the last dimension of raw model outputs.

    `log_softmax` is used instead of `log(softmax(x))`: the latter underflows to
    log(0) = -inf for very confident predictions, which makes the loss infinite
    and the gradients NaN.
    """
    return torch.log_softmax(outputs, dim=-1)

  def validation(self):
    """Evaluate the model on `valid_loader` without tracking gradients.

    Returns:
      (mean_loss, mean_acc): per-sample mean loss and accuracy over
      `valid_loader.dataset`.
    """
    self.model.eval()
    total_loss = 0.0
    total_correct = 0
    with torch.inference_mode():
      for images, labels in self.valid_loader:
        images, labels = images.to(self.device), labels.to(self.device)
        outputs = self.model(images)
        loss = self.criterion(self._log_probs(outputs), labels)
        total_correct += (outputs.argmax(dim=-1) == labels).sum().item()
        # The criterion averages over the batch. Re-weight by batch size so the
        # final division yields a per-sample mean even when the last batch is short.
        total_loss += loss.item() * len(labels)
    n_samples = len(self.valid_loader.dataset)
    return total_loss / n_samples, total_correct / n_samples

  def train_by_number_of_epochs(self, num_epochs):
    """Train for `num_epochs` epochs, validating and checkpointing after each one.

    Appends per-step training loss/accuracy and per-epoch validation
    loss/accuracy to the history lists. Saves `state_dict` to
    `<model_name>_best.pt` whenever validation accuracy beats `best_acc`.
    Finally plots the accuracy curves on the current matplotlib figure.
    """
    for _ in tqdm(range(num_epochs)):
      self.model.train()
      for images, labels in tqdm(self.train_loader, leave=False):
        images, labels = images.to(self.device), labels.to(self.device)
        self.optimizer.zero_grad()
        outputs = self.model(images)
        loss = self.criterion(self._log_probs(outputs), labels)
        acc = (outputs.argmax(dim=-1) == labels).float().mean()
        loss.backward()
        self.optimizer.step()

        self.train_losses.append(loss.item())
        self.train_accs.append(acc.item())

      # Validation
      valid_loss, valid_acc = self.validation()
      if valid_acc > self.best_acc:
        self.best_acc = valid_acc
        # len(self.valid_accs) is the 0-based index of the epoch just finished.
        print(f"Saving best model at epoch {len(self.valid_accs)}, acc: {valid_acc}")
        torch.save(self.model.state_dict(), f'{self.model_name}_best.pt')

      self.valid_losses.append(valid_loss)
      self.valid_accs.append(valid_acc)

    self._plot_accuracy()

  def _plot_accuracy(self):
    """Plot per-step training accuracy and per-epoch validation accuracy."""
    steps_per_epoch = len(self.train_loader)
    # Each validation point is placed at the last training step of its epoch.
    epoch_ends = range(steps_per_epoch - 1, len(self.train_accs), steps_per_epoch)
    plt.plot(self.train_accs, label='train (per step)')
    plt.plot(epoch_ends, self.valid_accs, label='valid (per epoch)')
    plt.xlabel('training step')
    plt.ylabel('accuracy')
    plt.title("Accuracy")
    plt.legend()

class Interpreter:
  """Run a trained classifier over a loader once and visualise its mistakes.

  Per-sample losses, predicted classes and true labels are computed in
  `__init__`. The loader must iterate its dataset in index order
  (`shuffle=False`): results are stored by position and matched back to
  `loader.dataset[i]` when plotting.
  """
  def __init__(self, model, test_loader):
    self.model = model
    self.loader = test_loader
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.criterion = nn.NLLLoss(reduction='none')  # keep one loss per sample
    self.test_labels = []  # filled in by infer_loader
    self.test_losses, self.test_preds = self.infer_loader()

  def infer_loader(self):
    """Run the model over the whole loader without gradients.

    Returns:
      (losses, predicted_classes): a CPU float tensor of per-sample losses and a
      list of int predictions, both in loader order. The true labels are also
      stored in `self.test_labels` as a list of int.
    """
    self.model.to(self.device)
    self.model.eval()
    n_samples = len(self.loader.dataset)
    losses = torch.zeros(n_samples)
    predicted_classes = torch.zeros(n_samples, dtype=torch.long)
    true_labels = torch.zeros(n_samples, dtype=torch.long)
    start = 0
    with torch.inference_mode():
      for imgs, labels in self.loader:
        imgs, labels = imgs.to(self.device), labels.to(self.device)
        outputs = self.model(imgs)
        end = start + len(labels)
        # log_softmax instead of log(softmax(.)): the latter gives -inf/inf
        # losses for very confident wrong predictions.
        log_probs = torch.log_softmax(outputs, dim=-1)
        losses[start:end] = self.criterion(log_probs, labels).cpu()
        predicted_classes[start:end] = outputs.argmax(dim=-1).cpu()
        true_labels[start:end] = labels.cpu()
        start = end
    self.test_labels = true_labels.tolist()
    return losses, predicted_classes.tolist()

  def plot_top_losses(self, topk=10):
    """Show the `topk` highest-loss samples with their true/predicted class and loss.

    `topk` is capped at the number of samples; nothing is drawn if there are none.
    """
    topk = min(topk, len(self.test_losses))
    if topk <= 0:
      return
    top_losses, top_idx = torch.topk(self.test_losses, k=topk)
    # squeeze=False keeps a 2-D Axes array even when topk == 1.
    fig, axes = plt.subplots(topk, 1, figsize=(5, 5*topk), squeeze=False)
    class_names = self.loader.dataset.classes
    for ax, loss, idx in zip(axes[:, 0], top_losses.tolist(), top_idx.tolist()):
      img, label = self.loader.dataset[idx]
      ax.imshow(tensor2pil(img))
      ax.set_title(f"True: {class_names[label]} /  Pred: {class_names[self.test_preds[idx]]} / Loss: {loss:.4f}, ")
      ax.axis('off')

  def plot_confusion_matrix(self):
    """Plot a count confusion matrix: rows are true labels, columns are predictions."""
    class_names = self.loader.dataset.classes
    n_cls = len(class_names)
    confusion_matrix = np.zeros((n_cls, n_cls), dtype=np.int64)
    # Use the labels recorded during inference instead of re-loading and
    # re-transforming every image via dataset[i] just to read its label.
    for true_idx, pred_idx in zip(self.test_labels, self.test_preds):
      confusion_matrix[true_idx, pred_idx] += 1
    plt.figure(figsize = (10,7))
    plt.imshow(confusion_matrix)
    plt.xticks(range(n_cls), class_names, rotation=90)
    plt.yticks(range(n_cls), class_names)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # ndenumerate yields (row, col); text() takes (x=col, y=row).
    for (row, col), value in np.ndenumerate(confusion_matrix):
      plt.text(col, row, f"{int(value)}", va="center", ha="center")


### Train with randomly initialized weights

In [None]:
# Baseline: ResNet-18 trained from random initialisation.
# `weights=None` replaces the deprecated `pretrained=False` argument.
model = torchvision.models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(trainset.classes))  # head sized to our classes
trainer = Trainer(model, train_loader, test_loader, 'without_pretraining')
trainer.train_by_number_of_epochs(num_epochs=60)
save_fig_with_date("model_without_pretraining")

# Inspect the last-epoch weights (the best epoch was checkpointed to
# 'without_pretraining_best.pt') on the test split.
interpreter = Interpreter(model, test_loader)
interpreter.plot_top_losses(topk=10)
save_fig_with_date("model_without_pretraining_top_losses")
interpreter.plot_confusion_matrix()
save_fig_with_date("model_without_pretraining_confusion_mat")

In [None]:
# Plot loss curve: per-step training loss and per-epoch validation loss
# for the trainer from the previous cell.
steps_per_epoch = len(trainer.train_loader)
# Each validation point is placed at the last training step of its epoch.
epoch_ends = range(steps_per_epoch - 1, len(trainer.train_losses), steps_per_epoch)
plt.plot(trainer.train_losses, label='train (per step)')
plt.plot(epoch_ends, trainer.valid_losses, label='valid (per epoch)')
plt.title("Loss")
plt.xlabel('training step')
plt.ylabel('NLL loss')
plt.ylim([0,3])
plt.legend()
# Save like every other figure in this notebook.
save_fig_with_date("model_without_pretraining_loss")

### Train with pre-trained weights


In [None]:
# Same ResNet-18 architecture, but initialised from weights pre-trained on
# ImageNet-1K. Only the final fully-connected layer is replaced with a new,
# randomly initialised head sized to our classes. All layers are then
# fine-tuned, since Trainer optimises every model parameter.
model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, len(trainset.classes))
trainer = Trainer(model, train_loader, test_loader, 'with_pretraining')
trainer.train_by_number_of_epochs(num_epochs=60)
save_fig_with_date("model_pretrained")

# `model` now holds the weights from the LAST of the 60 epochs. The epoch with
# the best validation accuracy was checkpointed to 'with_pretraining_best.pt'
# (see the "Saving best model" messages printed above). The plots below use the
# last-epoch weights; the "Load saved model" section re-evaluates the best one.

interpreter = Interpreter(model, test_loader)
interpreter.plot_top_losses(topk=10)
save_fig_with_date("model_pretrained_top_losses")
interpreter.plot_confusion_matrix()
save_fig_with_date("model_pretrained_confusion_mat")

#### Load saved model

In [None]:
# Reload the best-validation-accuracy checkpoint of the pre-trained run into
# `model`, which must be the same architecture (ResNet-18 with our class head
# from the previous cell).
# map_location='cpu' lets a checkpoint written on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
# torch.load unpickles the file, so only load checkpoints you created yourself.
saved_weight = torch.load('with_pretraining_best.pt', map_location='cpu')
model.load_state_dict(saved_weight) # load trained weights parameters to the model

interpreter = Interpreter(model, test_loader)
interpreter.plot_top_losses(topk=10)
save_fig_with_date("model_pretrained_best_epoch_top_losses")
interpreter.plot_confusion_matrix()
save_fig_with_date("model_pretrained_best_epoch_confusion_mat")

In [None]:

# Fine-tune a fresh ImageNet-pre-trained ResNet-18 on the current contents of
# IMG_DIR/train and IMG_DIR/test.
# NOTE(review): no cell in this notebook performs the "cleaning"; presumably
# images were removed from the splits by hand (e.g. moved to DEL_DIR) before
# running this cell — confirm.
# Datasets and loaders are rebuilt so that removed files are no longer listed.
trainset = ImageSet(IMG_DIR/'train')
testset = ImageSet(IMG_DIR/'test')
train_loader = DataLoader(trainset, batch_size=32, shuffle=True)
test_loader = DataLoader(testset, batch_size=32, shuffle=False)

model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, len(trainset.classes))
trainer = Trainer(model, train_loader, test_loader, 'after_cleaning')
trainer.train_by_number_of_epochs(num_epochs=15)
save_fig_with_date("model_after_cleaning")

# Evaluate the best-validation-accuracy checkpoint rather than the last epoch.
# map_location='cpu' lets a GPU-written checkpoint load on any machine.
model.load_state_dict(torch.load('after_cleaning_best.pt', map_location='cpu'))
interpreter = Interpreter(model, test_loader)
interpreter.plot_top_losses(topk=10)
save_fig_with_date("model_after_cleaning_top_losses")
interpreter.plot_confusion_matrix()
save_fig_with_date("model_after_cleaning_confusion_mat")

## Test with custom images

In [None]:
# Remove previously uploaded custom images (portable replacement for
# `! rm -rf custom_images/`, which does not work on Windows).
# WARNING: run this BEFORE uploading new images; it deletes the whole folder.
shutil.rmtree('custom_images', ignore_errors=True)

In [None]:
# YOU HAVE TO UPLOAD YOUR IMAGES (.jpg / .jpeg / .png) TO custom_images/ TO RUN THIS CELL
# Predicts a class for each uploaded image with the current `model` and saves
# a figure with one titled panel per image.
custom_image_dir = Path("custom_images/") # Path to directory containing custom images
custom_image_dir.mkdir(exist_ok=True, parents=True)

# Sorted (and de-duplicated) so the figure order is stable across runs.
custom_image_fns = sorted({
    fn for pattern in ("*.jpg", "*.jpeg", "*.png") for fn in custom_image_dir.glob(pattern)
})

model.cpu()
model.eval()
if not custom_image_fns:
  print(f"No images found in {custom_image_dir}/ - upload .jpg/.jpeg/.png files and re-run.")
else:
  plt.figure(figsize=(5, 5*len(custom_image_fns)))
  with torch.inference_mode():  # no autograd graph needed for prediction
    for i, fn in enumerate(custom_image_fns):
      plt.subplot(len(custom_image_fns), 1, i+1)
      # convert('RGB'): PNGs may be RGBA or greyscale, but the 3-channel
      # Normalize in the transform (and the model) require RGB.
      with Image.open(fn) as raw:
        img = raw.convert('RGB')
      plt.imshow(img)
      logits = model(testset.transform(img).unsqueeze(0))
      probs = logits.softmax(dim=-1)
      pred_class = trainset.classes[int(probs.argmax())]
      plt.title(f"Prediction: {pred_class} / Confidence: {probs.max().item():.3f}")
      plt.axis("off")
  save_fig_with_date("custom_images_predictions")
