In [1]:
import time
import pickle
import numpy as np
import pandas as pd
from gmm import *
import matplotlib.pyplot as plt
from multiprocessing import Pool
from collections import defaultdict
from scipy.stats import multivariate_normal as mvn
from sklearn.model_selection import train_test_split

# Global matplotlib defaults for every figure in this notebook:
# larger text, grid on, 8x6 inch canvas, Cambria as the serif face.
plt.rcParams.update({
    "font.size": 18,
    "axes.grid": True,
    "figure.figsize": (8, 6),
    "font.serif": "Cambria",
    "font.family": "serif",
})

%load_ext autoreload
%autoreload 2

In [2]:
# Read one training CSV for the "coast" class. The file has no header row,
# so pandas assigns integer column labels.
df = pd.read_csv(
    filepath_or_buffer="../datasets/2B/coast/train_0.csv",
    header=None,
)
# Plain ndarray view of the same table for numeric work.
X = df.to_numpy()
# Preview the first rows as the cell output.
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,13,14,15,16,17,18,19,20,21,22
0,0.9808,0.98211,0.41813,1.0148e-11,2.5903e-10,3.1159e-07,-0.29073,-0.39855,-1.3882,0.000529,...,0.000428,0.00766,0.003981,0.018343,1.6161,1.572,1.8232,3.2124,2.2108,2.6424
1,0.52013,0.28381,0.44948,2.3552e-08,1.1112e-05,2.319e-07,0.38395,0.7181,1.0005,0.000353,...,0.000176,0.005468,0.00194,0.027341,2.7736,2.6344,3.1931,1.3373,2.565,3.0839
2,0.68152,0.43121,0.30115,2.8956e-09,2.2498e-08,1.8187e-10,0.69114,0.058977,-0.051016,0.00068,...,0.000731,0.011112,0.0064,0.007131,1.7541,1.4404,1.016,2.807,2.6903,2.693
3,0.62619,0.48974,0.80199,4.1085e-12,2.8254e-08,7.8034e-09,0.60768,-0.49373,1.1772,0.000605,...,0.000731,0.009222,0.006299,0.011641,2.103,2.0801,1.5434,1.4591,3.0236,1.3988
4,0.60757,0.65997,0.86624,3.9775e-12,1.7394e-07,5.157e-08,0.22185,-1.2549,0.3049,0.000353,...,0.000605,0.005543,0.003981,0.022955,2.2475,1.4874,2.1702,2.3613,1.4043,1.6768


In [3]:
# Fit one GMM per scene class and persist each fitted model to
# results/<class_name>.pickle so later cells can reload it without retraining.
# This cell is slow: in the recorded run each class took between roughly
# 12 minutes and 1 h 13 min.
classes = ["coast", "highway", "mountain", "opencountry", "tallbuilding"]
for class_name in classes:
    # GMM_vl comes from the project's `gmm` module; q is presumably the number
    # of mixture components and tol the convergence tolerance - TODO confirm.
    gmm = GMM_vl(q=14, tol=1e-3)
    gmm.fit(class_name=class_name, epochs=30)

    fname = "results/" + class_name + ".pickle"
    # The context manager closes (and flushes) the file even if pickle.dump
    # raises, so a failed dump cannot leak an open handle.
    with open(fname, "wb") as model_file:
        pickle.dump(gmm, model_file)

100%|██████████| 36/36 [49:57<00:00, 83.27s/it]  
100%|██████████| 36/36 [11:38<00:00, 19.41s/it]
100%|██████████| 36/36 [1:07:20<00:00, 112.24s/it]
100%|██████████| 36/36 [1:12:47<00:00, 121.33s/it]
100%|██████████| 36/36 [27:56<00:00, 46.58s/it]


In [4]:
# Reload the per-class GMMs from results/<class_name>.pickle into a dict
# keyed by class name.
classes = ["coast", "highway", "mountain", "opencountry", "tallbuilding"]
class_wise_gmm = {}
for class_name in classes:
    fname = "results/" + class_name + ".pickle"
    # SECURITY: pickle.load can execute arbitrary code. Only load pickle files
    # that you produced yourself and trust.
    # The context manager closes the file even if unpickling raises.
    with open(fname, "rb") as model_file:
        gmm = pickle.load(model_file)

    class_wise_gmm[class_name] = gmm

In [5]:
# Build the multi-class classifier from the per-class GMMs loaded into
# `class_wise_gmm`. GMM_vl_classifier comes from the project's `gmm` module.
gmm_cl = GMM_vl_classifier()
# Hand the {class_name: fitted GMM} mapping to the classifier.
gmm_cl.compile(class_wise_gmm)
# NOTE(review): presumably evaluates every class model on the evaluation data
# (the recorded output shows many 36-step progress bars) - confirm in gmm.py.
gmm_cl.transform()

100%|██████████| 36/36 [00:01<00:00, 35.68it/s]
100%|██████████| 36/36 [00:00<00:00, 78.26it/s]
100%|██████████| 36/36 [00:00<00:00, 76.31it/s] 
100%|██████████| 36/36 [00:00<00:00, 51.90it/s]
100%|██████████| 36/36 [00:00<00:00, 57.00it/s]
100%|██████████| 36/36 [00:00<00:00, 61.94it/s]
100%|██████████| 36/36 [00:00<00:00, 72.98it/s]
100%|██████████| 36/36 [00:00<00:00, 81.53it/s]
100%|██████████| 36/36 [00:00<00:00, 91.56it/s] 
100%|██████████| 36/36 [00:00<00:00, 101.86it/s]
100%|██████████| 36/36 [00:00<00:00, 95.79it/s]
100%|██████████| 36/36 [00:00<00:00, 104.44it/s]
100%|██████████| 36/36 [00:00<00:00, 58.72it/s]
100%|██████████| 36/36 [00:00<00:00, 54.76it/s]
100%|██████████| 36/36 [00:00<00:00, 57.48it/s]
100%|██████████| 36/36 [00:00<00:00, 57.05it/s]
100%|██████████| 36/36 [00:00<00:00, 61.49it/s]
100%|██████████| 36/36 [00:00<00:00, 56.70it/s]
100%|██████████| 36/36 [00:00<00:00, 54.48it/s]
100%|██████████| 36/36 [00:00<00:00, 57.75it/s]
100%|██████████| 36/36 [00:00<00:00,

In [6]:
# Per-class accuracy: for each true class, take the arg-max column of its score
# table as the predicted label and compare it with the expected label index.
gmm_cl.classify(eps=1e-3)
acc = []                  # accuracy per true class, in gmm_class_names order
total = []                # number of scored rows per true class
classes_present = {"coast":0, "highway":1, "mountain":2, "opencountry":3, "tallbuilding":4}
classification_list = []  # predicted label indices per true class

for class_name in gmm_cl.gmm_class_names:
    scores = gmm_cl.class_wise_df[class_name]
    actual_class = classes_present[class_name]
    # The hard-coded label index is only meaningful if it equals the position
    # of that class's column in the score table; fail loudly otherwise instead
    # of silently reporting a wrong accuracy.
    assert scores.columns.get_loc(class_name) == actual_class, (
        "classes_present does not match the column order of class_wise_df"
    )
    # NOTE(review): the recorded "coast" score table contains inf in several
    # columns. np.argmax returns the first maximal index, so rows where several
    # columns are inf resolve to the lowest column index. The accuracies below
    # may reflect that overflow rather than the models - consider comparing
    # log-likelihoods instead.
    classification = np.argmax(scores.to_numpy(), axis=1)
    total.append(classification.size)
    acc.append(np.sum(classification == actual_class) / total[-1])
    classification_list.append(classification)

# One accuracy per line, in the same order as gmm_cl.gmm_class_names.
for class_accuracy in acc:
    print(class_accuracy)

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


1.0
0.43956043956043955
0.31417624521072796
0.024390243902439025
0.7710843373493976


In [7]:
# Inspect the per-class scores computed for the "coast" samples (one column per
# class model). NOTE(review): the recorded output shows inf in many cells, which
# makes arg-max comparisons between columns unreliable for those rows.
gmm_cl.class_wise_df["coast"]

Unnamed: 0,coast,highway,mountain,opencountry,tallbuilding
0,inf,1.982555e-111,4.971823e+16,6.634563e-25,6.313721e+99
1,inf,1.828361e+154,inf,inf,inf
2,inf,2.394975e-36,1.396985e+205,9.072771e+86,4.746419e+136
3,inf,8.346348e+31,inf,inf,inf
4,inf,1.797440e+21,3.708899e+231,inf,4.238122e+274
...,...,...,...,...,...
246,inf,4.818559e+151,inf,inf,inf
247,inf,2.021793e+15,2.871016e+198,8.475041e+206,4.828325e+228
248,inf,3.485264e+06,inf,inf,inf
249,inf,1.284193e+07,inf,inf,inf
