In [1]:
import torch, torchvision

from pathlib import Path
import numpy as np
import cv2
import pandas as pd
from tqdm import tqdm
import PIL.Image as Image
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib.ticker import MaxNLocator
from torch.optim import lr_scheduler
from glob import glob
import shutil
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from torch import nn, optim

import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import models

from os import listdir
from os.path import isfile, join
from pathlib import Path

plt.rcParams["figure.figsize"] = (12,8)

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cwd = Path().resolve()

In [2]:
def create_model(n_classes):
  model = models.resnet34(pretrained=False)

  n_features = model.fc.in_features
  model.fc = nn.Linear(n_features, n_classes)

  return model.to(device)

In [3]:
fashion_type_classes = 18
fashion_pattern_classes = 36
furniture_type_classes = 10
furniture_materials_classes = 7
color_classes=12

fashion_type_model = None
fashion_pattern_model = None
furniture_type_model = None
furniture_materials_model = None
color_model = None
###
fashion_type_model = create_model(fashion_type_classes)
fashion_type_model.load_state_dict(torch.load('best_model_state_fashion_type.bin'))

fashion_pattern_model = create_model(fashion_pattern_classes)
fashion_pattern_model.load_state_dict(torch.load('best_model_state_fashion_pattern.bin'))

furniture_type_model = create_model(furniture_type_classes)
furniture_type_model.load_state_dict(torch.load('best_model_state_furniture_type.bin'))

furniture_materials_model = create_model(furniture_materials_classes)
furniture_materials_model.load_state_dict(torch.load('best_model_state_furniture_materials.bin'))

color_model = create_model(color_classes)
color_model.load_state_dict(torch.load('best_model_state_color.bin'))

furniture_type_folder = "vision_data/furniture/type_data"
fashion_type_folder = "vision_data/fashion/type_data"
furniture_materials_folder = "vision_data/furniture/materials_data"
fashion_pattern_folder = "vision_data/fashion/pattern_data"
color_folder = "vision_data/color/color_data"

RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.weight: copying a param with shape torch.Size([31, 512]) from checkpoint, the shape in current model is torch.Size([12, 512]).
	size mismatch for fc.bias: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([12]).

In [None]:
furniture_type_labels = ['AreaRug', 'Bed', 'Chair', 'CoffeeTable', 'CouchChair', 'EndTable', 'Lamp', 'Shelves', 'Sofa', 'Table']
furniture_materials_labels = ['leather', 'marble', 'memory foam', 'metal', 'natural fibers', 'wood', 'wool']
fashion_type_labels = ['blouse', 'coat', 'dress', 'hat', 'hoodie', 'jacket', 'jeans', 'joggers', 'shirt', 'shirt, vest',
 'shoes', 'skirt', 'suit', 'sweater', 'tank top', 'trousers', 'tshirt', 'vest']
fashion_pattern_labels = ['camouflage', 'canvas', 'cargo', 'checkered', 'checkered, plain', 'denim', 'design', 'diamonds', 'dotted',
 'floral', 'heavy stripes', 'heavy vertical stripes', 'holiday', 'horizontal stripes', 'knit', 'leafy design', 'leapard print',
 'leather', 'light spots', 'light stripes', 'light vertical stripes', 'multicolored', 'plaid', 'plain', 'plain with stripes on side',
 'radiant', 'spots', 'star design', 'streaks', 'stripes', 'text', 'twin colors', 'velvet', 'vertical design', 'vertical stripes',
 'vertical striples']
color_labels = ['beige', 'black', 'blue', 'brown', 'dark blue', 'dark brown', 'dark green', 'dark grey', 'dark pink', 'dark red',
                'dark violet', 'dark yellow', 'dirty green', 'dirty grey', 'golden', 'green', 'grey', 'light blue', 'light grey',
                'light orange', 'light pink', 'light red', 'maroon', 'olive', 'orange', 'pink', 'purple', 'red', 'violet', 
                'white', 'yellow']

In [None]:
# color_original_folder = "vision_data/color/color_original_data"
# color_type_set = sorted(listdir(color_original_folder))
# print(color_type_set)


In [None]:
image_models={
    "furniture_type": {
        "model" : furniture_type_model,
        "folder" : furniture_type_folder,
        "classes" : furniture_type_labels
    },
    "furniture_materials": {
        "model" : furniture_materials_model,
        "folder" : furniture_materials_folder,
        "classes" : furniture_materials_labels
    },
    "fashion_type": {
        "model" : fashion_type_model,
        "folder" : fashion_type_folder,
        "classes" : fashion_type_labels
    },
    "fashion_pattern": {
        "model" : fashion_pattern_model,
        "folder" : fashion_pattern_folder,
        "classes" : fashion_pattern_labels
    },
    "color": {
        "model" : color_model,
        "folder" : color_folder,
        "classes" : color_labels
    }
}

In [None]:
def imshow(inp, title=None):
  inp = inp.numpy().transpose((1, 2, 0))
  mean = np.array([mean_nums])
  std = np.array([std_nums])
  inp = std * inp + mean
  inp = np.clip(inp, 0, 1)
  plt.figure(figsize = (20,2))
  plt.imshow(inp)
  if title is not None:
    plt.title(title)
  plt.axis('off')
  plt.show()

In [None]:
def show_predictions(model, class_names, data_loader, n_images=6):
  model = model.eval()
  images_handeled = 0
  plt.figure()

  with torch.no_grad():
    for i, (inputs, labels) in enumerate(data_loader):
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = model(inputs)
      _, preds = torch.max(outputs, 1)

      for j in range(inputs.shape[0]):
        images_handeled += 1
        ax = plt.subplot(n_images,1, images_handeled)
        ax.set_title(f'predicted: {class_names[preds[j]]} true: {class_names[labels[j]]}')
        imshow(inputs.cpu().data[j])
        ax.axis('off')

        if images_handeled == n_images:
          return

In [None]:
def get_predictions(model, data_loader):
  model = model.eval()
  predictions = []
  real_values = []
  with torch.no_grad():
    for inputs, labels in data_loader:
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = model(inputs)
      _, preds = torch.max(outputs, 1)
      predictions.extend(preds)
      real_values.extend(labels)
  predictions = torch.as_tensor(predictions).cpu()
  real_values = torch.as_tensor(real_values).cpu()
  return predictions, real_values

In [None]:
def show_confusion_matrix(confusion_matrix, class_names):

  cm = confusion_matrix.copy()

  cell_counts = cm.flatten()

  cm_row_norm = cm / cm.sum(axis=1)[:, np.newaxis]

  row_percentages = ["{0:.2f}".format(value) for value in cm_row_norm.flatten()]

  cell_labels = [f"{cnt}\n{per}" for cnt, per in zip(cell_counts, row_percentages)]
  cell_labels = np.asarray(cell_labels).reshape(cm.shape[0], cm.shape[1])

  df_cm = pd.DataFrame(cm_row_norm, index=class_names, columns=class_names)

  hmap = sns.heatmap(df_cm, annot=cell_labels, fmt="", cmap="Blues")
  hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
  hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
  plt.ylabel('True Sign')
  plt.xlabel('Predicted Sign');

In [None]:
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

transform = T.Compose([
  T.Resize(size=[256,256]),
  T.CenterCrop(size=256),
  T.ToTensor(),
  T.Normalize(mean_nums, std_nums)
])

In [None]:
def make_predictions(data_type):
    image_dataset = ImageFolder(f'{image_models[data_type]["folder"]}/val/',transform) 
    data_loader = DataLoader(image_dataset, batch_size=4, shuffle=True, num_workers=4) 
    class_names = image_models[data_type]["classes"]
    show_predictions(image_models[data_type]["model"], class_names, data_loader,n_images=6) ########################## 3
#     y_pred, y_test = get_predictions(image_models[data_type]["model"], data_loader)
#     print(len(y_pred))
#     print(len(class_names))
#     print(classification_report(y_test, y_pred, target_names=class_names))
#     cm = confusion_matrix(y_test, y_pred,labels = class_names)
#     show_confusion_matrix(cm, class_names)

In [None]:
make_predictions("fashion_type") # "furniture_type", "furniture_materials", "fashion_type", "fashion_pattern", "color"

In [None]:
# inputs, classes = next(iter(data_loader))
# out = torchvision.utils.make_grid(inputs)

# imshow(out, title=[class_names[x] for x in classes])

In [None]:
def predict_proba(model, image_path):
  img = Image.open(image_path)
  img = img.convert('RGB')
  img = transform(img).unsqueeze(0)

  print(img)

  pred = model(img.to(device))
  pred = F.softmax(pred, dim=1)
  _, pred_label = torch.max(pred, 1)

  return pred.detach().cpu().numpy().flatten(), pred_label

In [None]:
pred, pred_label = predict_proba(image_models["fashion_type"]["model"], 'samples/dress.jpeg')


In [None]:
print(image_models["fashion_type"]["classes"][pred_label])

In [None]:
def show_prediction_confidence(prediction, class_names):
  pred_df = pd.DataFrame({
    'class_names': class_names,
    'values': prediction
  })
  sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
  plt.xlim([0, 1]);

In [4]:
show_prediction_confidence(pred, image_models["fashion_type"]["classes"])

NameError: name 'show_prediction_confidence' is not defined

In [153]:
print(pred)

[7.0630126e-03 4.3446526e-02 1.6059436e-02 3.4127440e-02 6.0117621e-02
 2.0684065e-02 5.9490796e-04 6.0783437e-04 1.8854591e-01 5.8841822e-03
 1.3563641e-05 1.1707161e-05 1.1822331e-03 8.2772769e-02 5.1073563e-05
 1.8214837e-03 5.3529340e-01 1.7228044e-03]


In [154]:
print(pred)

[7.0630126e-03 4.3446526e-02 1.6059436e-02 3.4127440e-02 6.0117621e-02
 2.0684065e-02 5.9490796e-04 6.0783437e-04 1.8854591e-01 5.8841822e-03
 1.3563641e-05 1.1707161e-05 1.1822331e-03 8.2772769e-02 5.1073563e-05
 1.8214837e-03 5.3529340e-01 1.7228044e-03]
