In [1]:
%matplotlib inline


## YOGA ASANA CLASSIFICATION INFERENCE




In [2]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
import os
import copy
import glob
from PIL import Image
import pickle
import seaborn as sns
import pandas as pd

plt.ion()   # interactive mode

In [3]:
from cf_matrix import make_confusion_matrix

In [4]:
from tqdm.notebook import tqdm

## Dataset  

There data sets for asanas are very easily available to both - web scrape and download (seggregated datasets) free to experiment and explore.  

The images for the current model has been taken from an individual fellow [ML enthusiast](https://www.amarchenkova.com/2018/12/04/data-set-convolutional-neural-network-yoga-pose/) _(do check her post for some quantum computing too!)_   
There is also a bigger dataset available [here](https://oregonstate.box.com/s/4c5o6gilogogdm9m23tgtop7vw23e9vj),which can be used to further scale the project. However, a smaller dataset has been used for a quick prototype.  

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
class YogaInference():
    def __init__(self, model_paths, cn_path):
        
        with open(cn_path,"r") as op:
            self.class_names = [i.strip() for i in op.readlines()] 
            
        self.model_ft, self.model_conv = self.load_models(*model_paths)
        
        self.data_transforms =  transforms.Compose([
        transforms.Resize((512,512)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        
    def load_models(self, finetune_path, conv_feature_path):
        model_ft = models.resnet18(pretrained=True)
        num_ftrs = model_ft.fc.in_features
        # Here the size of each output is generalized
        model_ft.fc = nn.Linear(num_ftrs, len(self.class_names))
        model_ft = model_ft.to(device)
        model_ft.load_state_dict(torch.load(finetune_path,map_location=torch.device('cpu')))
        _ = model_ft.eval()
        
        
        model_conv = torchvision.models.resnet18(pretrained=True)
        for param in model_conv.parameters():
            param.requires_grad = False

        # Parameters of newly constructed modules have requires_grad=True by default
        num_ftrs = model_conv.fc.in_features
        model_conv.fc = nn.Linear(num_ftrs, len(self.class_names))

        model_conv.load_state_dict(torch.load(conv_feature_path,map_location=torch.device('cpu')))
        _ = model_conv.eval()

        return model_ft, model_conv
    
    
    @torch.no_grad()
    def validate_batch(self, test_path,model="fine_tune"):
        test_dataset = datasets.ImageFolder(test_path, self.data_transforms)
        test_dl = DataLoader(test_dataset, 4, num_workers=3, pin_memory=True)
        
        pred_batch_probs = []
        actual_batch_vals = []
        for xb, label in tqdm(test_dl):
            if model == "fine_tune":
                preds = self.model_ft(xb)
                
            if model == "conv_feature":
                preds = self.model_conv(xb)
                
            pred_batch_probs.append(preds)
            actual_batch_vals.append(label)
        
        pred_batch_probs = torch.cat(pred_batch_probs)
        
        #The below lines are to macth the validation output index with the actuals
        actual_batch_vals = torch.cat(actual_batch_vals)
        batch_classnames = [test_dl.dataset.classes[i] for i in actual_batch_vals]
        actual_batch_vals = [self.class_names.index(i) for i in batch_classnames]
        
        return torch.tensor(actual_batch_vals), pred_batch_probs
        
        
    @torch.no_grad()
    def get_output(self, image_names, model):
        output = None
        for idx, image_name in enumerate(image_names):
            image = Image.open(image_name)
            image = self.data_transforms(image).float()
            image = image.clone().detach()
            image = image.unsqueeze(0)
            if model == "fine_tune":
                preds = self.model_ft(image)
            
            if model == "conv_feature":
                preds = self.model_conv(image)

            if idx == 0:
                output = preds.numpy()
            else:
                output = np.vstack([output ,preds.numpy()])
        return output
    
    
    def classify_asana(self, images, model="fine_tune", batch=False, raw_out=False):
        if batch:
            output = self.get_output(images, model)
        else:
            output = self.get_output([images], model)
            
        asanas = []
        idx = np.argmax(output,axis=1)
        asanas = [self.class_names[i] for i in idx]
        
        if raw_out:
            return asanas, output
        return asanas, None

In [7]:
os.listdir("inference_images/batch_validation/")

['adhomukha_shwanasana', 'balasana', 'phalakasana', 'vrikshasana', 'tadasana']

In [8]:
model_paths = ["../models/finetuned_model.pth", "../models/convfeat_model.pth"]
classnames_path = "../models/class_names.txt"

In [9]:
yoga_inf = YogaInference(model_paths, classnames_path)

In [10]:
with open(classnames_path,"r") as op:
    class_names = [i.strip() for i in op.readlines()] 

In [11]:
images = glob.glob(os.path.join("inference_images/unseen_images/","*.jpeg"))

#### Unknown class verification

In [12]:
yoga_inf.classify_asana(images, model="fine_tune",batch=True)

(['sethu_bandhasana',
  'tadasana',
  'adhomukha_shwanasana',
  'phalakasana',
  'paschimottanasana'],
 None)

In [13]:
def get_top_3(file_name):
    out, predictions = yoga_inf.classify_asana(file_name, model="fine_tune",batch=False, raw_out=True)
    page = torch.nn.functional.softmax(torch.Tensor(predictions),dim=-1)
    top_prob, top_indices = torch.topk(page,3,axis=1)

    for i in range(3):
        print (class_names[top_indices[0][i].item()],top_prob[0][i].item())

In [14]:
get_top_3(images[0])

sethu_bandhasana 0.9932143688201904
balasana 0.005707908421754837
vrikshasana 0.00038166571175679564


#### Known class verification

In [15]:
# #data_transforms =  transforms.Compose([
# transforms.Resize((512,512)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


In [16]:
# test_dataset = datasets.ImageFolder("inference_images/batch_validation/", data_transforms)
# test_dl = DataLoader(test_dataset, 4, num_workers=3, pin_memory=True)

In [17]:
# test_dl.dataset.classes

In [18]:
act, pred = yoga_inf.validate_batch("inference_images/batch_validation/", model="fine_tune")

HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))

  "Palette images with Transparency expressed in bytes should be "
  "Palette images with Transparency expressed in bytes should be "





In [27]:
def compute_confusion_matrix(true, pred):
    '''Computes a confusion matrix using numpy for two np.arrays
    true and pred.

    Results are identical (and similar in computation time) to: 
    "from sklearn.metrics import confusion_matrix"

    However, this function avoids the dependency on sklearn.'''

    K = len(np.unique(true)) # Number of classes 
    result = np.zeros((K, K))
    
    for i in range(len(true)):
        try:
            result[true[i]][pred[i]] += 1
        except IndexError:
            print ("The predicted value is outside the true value in the batch")
            continue

    return result

In [28]:
cm = compute_confusion_matrix(act.numpy(), torch.argmax(pred, axis=1).numpy())

The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the batch
The predicted value is outside the true value in the bat

In [29]:
cm

array([[70.,  0.,  0.,  0.,  0.],
       [ 1., 65.,  5.,  0.,  1.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0., 57.,  0.],
       [ 0.,  0.,  0.,  0.,  0.]])

In [None]:
categories = [class_names[i] for i in act.unique()]
make_confusion_matrix(cm, categories=categories,figsize=(10,10))

In [None]:
def analyse_preds(error_only=True):    
    page = torch.nn.functional.softmax(pred, dim=-1)
    top_prob, top_indices = torch.topk(page,3,axis=1)   
    
    if error_only:
        non_matched = torch.where(act!=top_indices[:,0])
    else:
        non_matched = list(range(len(page)))
        
    df = pd.DataFrame()
    df['actuals'] = [class_names[i.item()] for i in act[non_matched]]
    df['predicted_1'] = [class_names[i.item()] for i in top_indices[non_matched][:,0]]
    df['probs_1'] = top_prob[non_matched][:,0]
    
    df['predicted_2'] = [class_names[i.item()] for i in top_indices[non_matched][:,1]]
    df['probs_2'] = top_prob[non_matched][:,1]
    
    return df

In [None]:
df = analyse_preds(error_only=True)

In [None]:
df

#### Training Plots

In [None]:
import pickle

def analyse_training(pickle_path = "fine_tune_2.pkl"):
    with open("fine_tune_2.pkl", "rb") as op:
        x = pickle.load(op)
        
    df = pd.DataFrame(x)
    df.columns = ['phase', 'loss', 'accuracy']
    
    return df    

In [None]:
df = analyse_training(pickle_path="fine_tune.pkl")

In [None]:
def plot_metadata(df, info="accuracy"):
    val_scores = df[df["phase"] == "val"][info]
    train_scores = df[df["phase"] == "train"][info]
    plt.plot(val_scores,'r', label='validation')
    plt.plot(train_scores,'b', label='training')
    plt.xlabel('epoch')
    plt.ylabel(info)
    plt.legend()
    plt.title(info + ' vs. no. of epochs')

In [None]:
plot_metadata(df, info="accuracy")

In [None]:
plot_metadata(df, info="loss")