In [None]:
%matplotlib inline
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import torchsample
from torchsample import transforms as ts_transforms
import matplotlib.pyplot as plt
import time
import copy
import os
from PIL import Image
from sklearn.manifold import TSNE
import seaborn as sns

#from torchsample.transforms import RangeNorm

import functions.fine_tune as ft
import functions.stats as st

plt.ion()   # interactive mode: switch matplotlib's pyplot interface into interactive drawing for this session

In [None]:
# Fine-tuned CIR classifier saved as a whole pickled nn.Module (not a state_dict).
# torch.load unpickles arbitrary objects, so only point this at checkpoints you trust.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
# full-module pickles; on such versions pass weights_only=False here.
MODEL_PATH = './saved_models/resnet_real_multisoft_1300_100epoch_July19  13:56:48'
model_ft = torch.load(MODEL_PATH)

# eval() is exactly train(False): dropout/batch-norm switch to inference behaviour.
model_ft.eval()
print(model_ft)

In [None]:
# Which image collection to evaluate on: 'real' or 'synthetic'.
dataset = 'real'

# dataset name -> (train root, validation root, train kwargs, validation kwargs)
# for st.subsetCreator. Fresh kwarg dicts/lists are built each time the cell runs.
_DATASET_SPLITS = {
    'real': (
        '//home//mtezcan//Documents//Hoarding//Data_Sets//pruned//good//train//',
        '//home//mtezcan//Documents//Hoarding//Data_Sets//pruned//good//val//',
        dict(multi_dir=False, im_per_room=100, roomdirs=['']),
        dict(im_per_room=100, multi_dir=False, roomdirs=['']),
    ),
    'synthetic': (
        '//media//mtezcan//New Volume//HoardingImages//_rated//',
        '//media//mtezcan//New Volume//HoardingImages//_val//',
        dict(im_per_room=10),
        dict(im_per_room=40),
    ),
}

if dataset not in _DATASET_SPLITS:
    raise ValueError('Undefined dataset '+dataset)

_train_root, _val_root, _train_kwargs, _val_kwargs = _DATASET_SPLITS[dataset]

# Training split first, then validation; `rootdir` ends up holding the validation root.
rootdir = _train_root
tr_dirs, tr_cir, tr_house, tr_room = st.subsetCreator(rootdir, **_train_kwargs)
rootdir = _val_root
test_dirs, test_cir, test_house, test_room = st.subsetCreator(rootdir, **_val_kwargs)

print('Created random directories')
for _split_dirs in (tr_dirs, test_dirs):
    print(len(_split_dirs))

In [None]:
# Run the fine-tuned classifier over every image of both splits.
# NOTE(review): outsize=9 presumably matches the classifier's 9 outputs (one per CIR
# level, given the argmax+1 applied afterwards) — confirm against st.extractFeats.
tr_smax, test_smax = [
    st.extractFeats(split_dirs, model_ft, outsize=9)
    for split_dirs in (tr_dirs, test_dirs)
]
for split_scores in (tr_smax, test_smax):
    print(split_scores.shape)

In [None]:
# Predicted CIR = 1-based index of the highest-scoring column
# (assumes column k scores CIR level k+1 — confirm against the training labels).
tr_pred, test_pred = (np.argmax(scores, axis=1) + 1 for scores in (tr_smax, test_smax))
print(tr_pred.shape)
print(test_pred.shape)

# Absolute distance between ground-truth and predicted CIR for each image.
tr_err, test_err = np.abs(tr_cir - tr_pred), np.abs(test_cir - test_pred)

# CIR-1: fraction of images whose prediction is within one level of the truth.
for message, abs_err in (('Training CIR-1 is ', tr_err), ('Test CIR-1 is ', test_err)):
    print(message + str(np.mean(abs_err <= 1)))

In [None]:
# CIR-1 (fraction of predictions within one level of ground truth) on consecutive
# slices of the validation set, each labelled with its index range so the lines
# can be told apart.
# NOTE(review): the 900/1800 boundaries are hard-coded; what each slice represents
# depends on the order st.subsetCreator returns images in — confirm before interpreting.
SLICE_BOUNDS = (0, 900, 1800, None)  # None = through the end of test_err
for start, stop in zip(SLICE_BOUNDS[:-1], SLICE_BOUNDS[1:]):
    err_slice = test_err[start:stop]
    label = 'Test CIR-1 on [{}:{}] is '.format(start, '' if stop is None else stop)
    if err_slice.size == 0:
        # np.mean of an empty array would print nan and emit a RuntimeWarning.
        print(label + 'n/a (empty slice)')
    else:
        print(label + str(np.mean(err_slice <= 1)))

In [None]:

# Distribution of predicted CIR levels on the validation split.
# With bins=range(1, 11), integer level k falls in bin [k, k+1); align='left'
# centres that bar on k so bars line up with the integer tick labels.
plt.figure()
plt.hist(test_pred, bins=range(1, 11), align='left')
plt.xticks(range(1, 10))
plt.title('Predicted CIR on the validation set')
plt.xlabel('Predicted CIR')
plt.ylabel('Number of images');


In [None]:
# Class-based CIR scores on the validation split; only the middle one is displayed.
# NOTE(review): names suggest CIR-0/1/2 tolerances — confirm in functions.stats.
class_cir_scores = st.class_based_cirs(test_cir, test_pred)
cir0, cir1, cir2 = class_cir_scores
print(cir1)

In [None]:
def _plot_prediction_boxplot(ground_truth, predictions, title):
    """Draw predicted CIR grouped by ground-truth CIR as box plots on a new 15x10 figure.

    Args:
        ground_truth: per-image ground-truth CIR levels (x grouping).
        predictions: per-image predicted CIR levels (y values).
        title: figure title.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    # x=/y= keywords work on old seaborn and are required from seaborn 0.12 on,
    # where the only positional parameter is `data`.
    sns.boxplot(x=ground_truth, y=predictions, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Ground Truth')
    ax.set_ylabel('Predictions')
    return ax


# Titles follow the `dataset` switch ('real' -> 'Real', 'synthetic' -> 'Synthetic')
# instead of always claiming the synthetic set.
_set_name = dataset.capitalize()
_plot_prediction_boxplot(tr_cir, tr_pred, _set_name + ' Training Set')
_plot_prediction_boxplot(test_cir, test_pred, _set_name + ' Validation Set');

In [None]:
def _joint_plot(ground_truth, predictions, title):
    """Joint scatter of ground-truth vs predicted CIR with marginal histograms.

    sns.jointplot creates its own figure, so no plt.figure() is needed beforehand;
    the title is attached to that figure rather than to whichever axes is current.

    Returns:
        The seaborn JointGrid.
    """
    # NOTE(review): `rug` is a distplot option; seaborn >= 0.11 draws jointplot
    # marginals with histplot, which may reject it — drop it on newer seaborn.
    grid = sns.jointplot(x=ground_truth, y=predictions,
                         marginal_kws=dict(bins=range(1, 11), rug=True))
    grid.fig.suptitle(title)
    # Leave room above the top marginal axes for the suptitle.
    grid.fig.subplots_adjust(top=0.93)
    return grid


_joint_plot(tr_cir, tr_pred, 'Joint plot for training')
_joint_plot(test_cir, test_pred, 'Joint plot for validation');

In [None]:
# Penultimate-layer features: every child module of the fine-tuned network except the
# last one (presumably the final classification layer — see the printed model above).
model_params = list(model_ft.children())
network = nn.Sequential(*model_params[:-1])
print(network)

# NOTE(review): outsize=2048 assumes a 2048-wide backbone output (e.g. ResNet-50;
# ResNet-18/34 produce 512) — confirm against the printed network.
tr_fvec, test_fvec = (
    st.extractFeats(split_dirs, network, batchsize=8, outsize=2048)
    for split_dirs in (tr_dirs, test_dirs)
)
print('Constructed Features')

np.set_printoptions(suppress=True)
# Fixed random_state makes each embedding reproducible. The two splits are fitted
# separately, so their 2-D coordinates are not comparable with each other.
model = TSNE(n_components=2, random_state=0)
tr_tsne, test_tsne = (model.fit_transform(fvec) for fvec in (tr_fvec, test_fvec))

print('Constructed t-SNE embeddings')

In [None]:
# Independent C-ordered snapshot of the validation t-SNE embedding, so later
# reassignments of test_tsne do not affect it.
tsne_resnet = np.copy(test_tsne, order='C')

In [None]:
# Side-by-side t-SNE maps coloured by ground-truth CIR: training (left), validation (right).
plt.figure(figsize=(20, 8))
_panels = (
    (121, tr_tsne, tr_cir, 'Training'),
    (122, test_tsne, test_cir, 'Validation'),
)
for subplot_code, embedding, cir_labels, panel_title in _panels:
    plt.subplot(subplot_code)
    plt.scatter(embedding[:, 0], embedding[:, 1], c=cir_labels, alpha=0.5)
    plt.colorbar()
    plt.title(panel_title)

In [None]:
# Full set of per-split t-SNE plots from functions.stats, training first, then test.
for split_arrays, split_title in (
        ((tr_tsne, tr_cir, tr_house, tr_room), 'Training'),
        ((test_tsne, test_cir, test_house, test_room), 'Test')):
    st.plotAll(*split_arrays, data_title=split_title)

In [None]:
# NOTE(review): this cell repeats the previous plotAll cell verbatim and redraws the
# same figures — remove it, or change its inputs if a different view was intended.
_training_split = (tr_tsne, tr_cir, tr_house, tr_room)
_test_split = (test_tsne, test_cir, test_house, test_room)
st.plotAll(*_training_split, data_title='Training')
st.plotAll(*_test_split, data_title='Test')