### 2018/2019 - Task List 10

1. Implement Naive Bayes classifier with pyro
    - create appropriate parameters (mean and std for a and b, sigma - noise)
    - provide optimization procedure
    - check appropriateness of implemented method with selected dataset


# Required imports

In [1]:
%matplotlib inline
import pyro
import torch
import numpy as np
import matplotlib.pyplot as plt
import pyro.optim as optim
import pyro.distributions as dist
from torch.distributions import constraints
from tqdm import tqdm_notebook as tqdm
import seaborn as sns
from matplotlib import animation, rc
from IPython.display import HTML
import torch.nn as nn
from functools import partial
import pandas as pd
from pyro.contrib.autoguide import AutoDiagonalNormal
from pyro.infer import EmpiricalMarginal, SVI, Trace_ELBO, TracePredictive

In [2]:
pyro.set_rng_seed(1)
pyro.enable_validation(True)

## Solutions

In [3]:
def read_dataset(path='./data_sets/seeds.data'):
    """Load the seeds data set and split it into features and class labels.

    Args:
        path: Location of the tab-separated, header-less data file. Defaults
            to the path this notebook has always read from.

    Returns:
        A ``(features, labels)`` pair of NumPy arrays. ``features`` holds
        every column except the last; ``labels`` holds the last column
        (``decision``) and keeps a trailing axis of length 1.
    """
    data_seeds = pd.read_csv(path, sep="\t", header=None)
    data_seeds.columns = ["Area A", "Perimeter P", "Compactness C = 4*pi*A/P^2", "Length of kernel",
                          "Width of kernel", "Asymmetry coefficient", "Length of kernel groove", "decision"]
    # DataFrame.get_values() was removed in pandas 1.0; to_numpy() is the
    # supported way to obtain the underlying array.
    features = data_seeds.iloc[:, :-1].to_numpy()
    labels = data_seeds.iloc[:, -1:].to_numpy()
    return features, labels

In [595]:
def pdf_normal(value, mean, std):
    """Evaluate the Normal(mean, std) probability density at ``value``.

    The computation stays in torch, so gradients flow back to ``mean`` and
    ``std`` when they require grad. Inputs may be Python numbers, NumPy
    values or tensors and are broadcast against each other.

    Args:
        value: Point(s) at which the density is evaluated.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.

    Returns:
        A tensor of densities. Wherever ``std`` is exactly zero the density
        is undefined and the fallback value ``0.001`` is returned instead.
    """
    value, mean, std = (
        t if t.is_floating_point() else t.to(torch.get_default_dtype())
        for t in (torch.as_tensor(value), torch.as_tensor(mean), torch.as_tensor(std))
    )

    # Shouldn't happen, but guard against division by zero. The zero entries
    # are swapped for 1 before dividing so no inf/NaN reaches the gradients.
    degenerate = std == 0
    safe_std = torch.where(degenerate, torch.ones_like(std), std)

    sqrt_two_pi = float(np.sqrt(2.0 * np.pi))
    density = torch.exp(-((value - mean) ** 2) / (2 * safe_std ** 2)) / (safe_std * sqrt_two_pi)
    return torch.where(degenerate, torch.full_like(density, 0.001), density)

def norm_probs(values, decisions_mean_std):
    """Return Gaussian naive Bayes log-likelihoods for every class.

    Attributes are treated as conditionally independent given the class, so
    the class log-likelihood is the SUM of the per-attribute normal
    log-densities (the product of densities in probability space). The
    result is unnormalised and can be used directly as ``logits``.

    Args:
        values: Attribute values of one observation, shape ``(attr_num,)``,
            or a batch of observations, shape ``(n, attr_num)``.
        decisions_mean_std: Pair ``(means, stds)``; both are indexed as
            ``[attribute][decision]``, i.e. shape ``(attr_num, decisions_num)``.
            Standard deviations are expected to be non-negative.

    Returns:
        A tensor of log-likelihoods with shape ``(decisions_num,)`` for a
        single observation or ``(n, decisions_num)`` for a batch. Gradients
        flow back to ``means`` and ``stds`` when they require grad.

    Raises:
        ValueError: If ``decisions_mean_std`` does not contain both the
            means and the standard deviations.
    """
    if len(decisions_mean_std) < 2:
        raise ValueError("decisions_mean_std must be a (means, stds) pair")

    # matrix of means, and another for std
    means, stds = decisions_mean_std[0], decisions_mean_std[1]
    values, means, stds = (
        t if t.is_floating_point() else t.to(torch.get_default_dtype())
        for t in (torch.as_tensor(values), torch.as_tensor(means), torch.as_tensor(stds))
    )

    # A zero std has no density; fall back to the density 0.001 for those
    # cells. The zeros are swapped for 1 first so no inf/NaN is produced.
    degenerate = stds == 0
    safe_stds = torch.where(degenerate, torch.ones_like(stds), stds)

    # values[..., attr, None] broadcasts against means[attr, decision].
    z = (values.unsqueeze(-1) - means) / safe_stds
    log_density = -0.5 * z ** 2 - torch.log(safe_stds) - 0.5 * float(np.log(2.0 * np.pi))
    log_density = torch.where(
        degenerate, torch.full_like(log_density, float(np.log(0.001))), log_density
    )

    # Sum over the attribute axis: one log-likelihood per class.
    return log_density.sum(dim=-2)
    
    
def model(data_x, data_y):
    """Generative model: per-class, per-attribute Gaussians scored against labels.

    Samples a mean and a standard deviation inside the nested "attributes"
    and "decisions" plates (sized by the module-level ``attr_num`` and
    ``decisions_num``), turns every row of ``data_x`` into class logits with
    ``norm_probs`` and observes ``data_y`` under a Categorical over them.

    Args:
        data_x: Observations, one row of attribute values per object.
        data_y: Class labels, one per row of ``data_x``; a trailing axis of
            length 1 is accepted.
    """
    with pyro.plate("decisions", size=decisions_num) as decision:
        # we don't have to sample prob of class - skip it
        with pyro.plate("attributes", size=attr_num):
            # The site names interpolate the value bound by the "decisions"
            # plate; they are kept exactly as they were so that they keep
            # matching the names used by the guide.
            # The samples are used directly: wrapping them in torch.tensor()
            # copied them and detached them from the autograd graph.
            mean = pyro.sample(f"mean-{decision}", pyro.distributions.Normal(3., 1.))
            # todo check if it is ok to hold the abs outside sample
            std = abs(pyro.sample(f"std-{decision}", pyro.distributions.Normal(1.5, 1.)))

    # (means, stds) in the layout norm_probs expects
    decisions_mean_std = (mean, std)

    # as_tensor accepts both an ndarray and a tensor result and, unlike
    # torch.tensor(), does not detach a tensor from its graph.
    probs_of_classes = torch.stack(
        [torch.as_tensor(norm_probs(x, decisions_mean_std)) for x in data_x]
    )

    with pyro.plate("map", len(data_y)):
        # reshape(-1) drops the trailing label axis but, unlike squeeze(),
        # keeps a 1-D tensor when there is only a single observation.
        pyro.sample("obs", dist.Categorical(logits=probs_of_classes), obs=data_y.reshape(-1))

def guide(data_x, data_y):
    """Variational guide: an independent Normal for every latent mean and std.

    Uses the same plates and sample-site names as ``model``. Every
    (attribute, decision) cell gets its own learnable location and scale;
    with a single scalar shared by all cells the approximate posterior could
    not assign different means to different classes.

    Args:
        data_x: Observations (unused here; kept to match ``model``'s signature).
        data_y: Class labels (unused here; kept to match ``model``'s signature).

    Returns:
        The sampled ``(mean, std)`` pair, with ``std`` made non-negative.
    """
    # One variational parameter per (attribute, decision) cell.
    cell_shape = (attr_num, decisions_num)

    with pyro.plate("decisions", size=decisions_num) as decision:
        # we don't have to sample prob of class - we can skip it
        with pyro.plate("attributes", size=attr_num):
            mean_mean_param = pyro.param(f"mean-{decision}-mean", torch.full(cell_shape, 10.))
            mean_std_param = pyro.param(f"mean-{decision}-std", torch.full(cell_shape, 5.),
                                        constraint=constraints.positive)
            # Use the samples directly; torch.tensor() on a tensor only makes
            # a detached copy.
            mean = pyro.sample(f"mean-{decision}", pyro.distributions.Normal(mean_mean_param, mean_std_param))

            std_mean_param = pyro.param(f"std-{decision}-mean", torch.full(cell_shape, 0.))
            std_std_param = pyro.param(f"std-{decision}-std", torch.full(cell_shape, 1.),
                                       constraint=constraints.positive)
            std = abs(pyro.sample(f"std-{decision}", pyro.distributions.Normal(std_mean_param, std_std_param)))

    decisions_mean_std = (mean, std)
    return decisions_mean_std

def train(data_x, data_y, num_iterations=1000, lr=0.005):
    """Fit ``guide`` to ``model`` with stochastic variational inference.

    Clears the Pyro parameter store, then runs ``num_iterations`` Adam steps
    on the Trace_ELBO loss, showing the current loss in a progress bar.

    Args:
        data_x: Array-like of observations, one row per object.
        data_y: Array-like of class labels, one per row of ``data_x``.
        num_iterations: Number of SVI steps to run.
        lr: Adam learning rate.

    Returns:
        The pair ``(model, svi)``.
    """
    pyro.clear_param_store()

    # Named ``adam`` so the file-level ``optim`` module alias is not shadowed.
    adam = pyro.optim.Adam({"lr": lr})
    # The deprecated ``num_samples`` argument is not passed any more; it has
    # no effect on ``svi.step``.
    svi = pyro.infer.SVI(model, guide, adam, loss=pyro.infer.Trace_ELBO())

    # NOTE(review): labels are forwarded unchanged as float32; Categorical
    # expects class indices starting at 0 - confirm the data set's encoding.
    data_x_tensor = torch.as_tensor(data_x, dtype=torch.float32)
    data_y_tensor = torch.as_tensor(data_y, dtype=torch.float32)

    progress = tqdm(range(num_iterations))
    for _ in progress:
        loss = svi.step(data_x_tensor, data_y_tensor)
        progress.set_postfix(loss=loss)
    return (model, svi)

In [596]:
# Load the data once; decisions_num and attr_num are module-level values
# that model() and guide() read for their plate sizes.
data_x, data_y = read_dataset()

observations_num = data_x.shape[0]
decisions_num = np.unique(data_y).size
attr_num = data_x.shape[1]

# Train on the first five observations only.
train(data_x[:5], data_y[:5])


# check parameters
# print("Check parameters:")
# for name, value in pyro.get_param_store().items():
#     print(name, pyro.param(name))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))





























































KeyboardInterrupt: 

In [578]:
# A (5, 1) column of ones; squeeze() drops the length-1 axis.
tmp = torch.ones(5, 1)
tmp.squeeze()

tensor([1., 1., 1., 1., 1.])

In [14]:
num_attr = 10
num_class = 3

# One standard-normal draw per (class, attribute) cell for the means ...
mu = pyro.sample(
    'mu_attr',
    dist.Normal(torch.zeros(num_class, num_attr), torch.ones(num_class, num_attr)),
)
# ... and the absolute value of a second such draw for the spreads.
sigma = pyro.sample(
    'sigma_attr',
    dist.Normal(torch.zeros(num_class, num_attr), torch.ones(num_class, num_attr)),
).abs()


In [16]:
def model(x_data, y_data):
    """Scratch stub: observe ``y_data`` at a ``'class'`` site inside the "decisions" plate.

    ``x_data`` is not used.

    NOTE(review): this reuses the name ``model``; when the file is executed
    top to bottom it replaces the earlier ``model(data_x, data_y)``.
    NOTE(review): ``dist.Categorical()`` is created without ``probs`` or
    ``logits``; torch's Categorical expects one of them, so this call looks
    like it would raise - confirm which parameters were intended.
    """
    # The plate is sized by the module-level number of classes.
    # NOTE(review): presumably it should match the length of y_data - confirm.
    with pyro.plate("decisions", size=decisions_num) as _:
        pyro.sample('class', dist.Categorical(), obs=y_data)
In [None]:
# Print one Normal(1, 2) draw made inside a plate of size num_attr, once per class.
for _ in range(num_class):
    with pyro.plate("tst", size=num_attr):
        print(pyro.sample('class', dist.Normal(1, 2)))

In [None]:
# Select the entries that are not equal to themselves, i.e. the NaN entries
# (``=!`` was a syntax error; the comparison operator is ``!=``).
# Assumes lista is a tensor/array that supports boolean-mask indexing.
lista[lista != lista]
# Scratch check of an f-string built from a class identifier.
f"std--{clazz}"

In [481]:
# Register a scalar parameter, then sample an observed site and inspect
# what comes back.
pyro.param("qwe", torch.tensor(10.))
tst = pyro.sample(
    'classzx',
    dist.Categorical(probs=torch.tensor([10.])),
    obs=[1, 2, 3],
)
print('tst', tst, type(tst))

tst [1, 2, 3] <class 'list'>


In [562]:
# Hand-written logits: five observations, three classes.
probs_of_classes_tst = torch.tensor([
    [-3.7433, 0.1131, -1.1246],
    [-2.9129, 0.2928, 1.0694],
    [-2.0371, -0.2617, -1.1063],
    [-2.1366, -0.6212, -1.1533],
    [-3.3112, -0.8562, -1.1786],
])
print(type(probs_of_classes_tst[0][0]))

# One class draw per row of logits.
tmp = dist.Categorical(logits=probs_of_classes_tst).sample()

with pyro.plate("map", 5) as idx:
    # Categorical over the rows selected by the plate's indices.
    decisions_categorical = dist.Categorical(logits=probs_of_classes_tst[idx])
    print('probs_of_classes[idx]', probs_of_classes_tst[idx])
    print('decisions_categorical', decisions_categorical)
    # Observed site: the first five labels of the data set are passed as obs.
    tmp = pyro.sample("obs", dist.Categorical(logits=probs_of_classes_tst), obs=data_y[:5])
tmp

# with pyro.plate("map", 5):
#     # normalized probs of class, for every object
#     # todo change probs to logits
#     decisions_categorical = dist.Categorical(logits=probs_of_classes_tst)
#     pyro.sample("obs", decisions_categorical, obs=data_y[:5])

<class 'torch.Tensor'>
probs_of_classes[idx] tensor([[-3.7433,  0.1131, -1.1246],
        [-2.9129,  0.2928,  1.0694],
        [-2.0371, -0.2617, -1.1063],
        [-2.1366, -0.6212, -1.1533],
        [-3.3112, -0.8562, -1.1786]])
decisions_categorical Categorical(logits: torch.Size([5, 3]))


array([[1],
       [1],
       [1],
       [1],
       [1]])

In [156]:
probs_of_classes_for_test = torch.tensor([[-39.8658,  -3.4863,  -8.9305],
        [-39.7600,  -3.0195,  -8.9025],
        [-35.6060,  -1.9818,  -8.8773],
        [-28.9508,  -1.6515,  -8.8401],
        [-39.5270,  -2.9536,  -8.5969],
        [-39.4662,  -2.2626,  -8.9279],
        [-39.7617,  -2.9218,  -6.6563],
        [-36.0674,  -2.0222,  -8.9235],
        [-39.8271,  -3.2191,  -8.7254],
        [-39.6878,  -3.1633,  -8.6723],
        [-33.1179,  -3.0476,  -2.4864],
        [-38.6026,  -2.2107,  -9.0817],
        [-32.4254,  -1.8579,  -4.6747],
        [-34.2335,  -1.9926,  -8.6433],
        [-33.7778,  -1.9837,  -9.0515],
        [-37.0230,  -2.3126,  -3.8038],
        [-24.2996,  -1.1251,  -0.8649],
        [-39.5019,  -2.9176,  -8.6604],
        [-39.2818,  -2.1377,  -8.7241],
        [-15.3725,  -1.0560,  -4.1575],
        [-39.9980,  -2.8474,  -8.8211],
        [-39.8377,  -2.4680,  -9.1041],
        [-39.5750,  -3.0546,  -8.6696],
        [ -6.5187,  -0.6974,  -9.4485],
        [-39.9214,  -3.5234,  -9.0217],
        [-39.7203,  -3.4103,  -8.7842],
        [-21.8318,  -1.3911,  -7.7847],
        [-18.6371,  -1.2417,  -9.4137],
        [-39.1549,  -2.2473,  -9.0190],
        [-32.4265,  -1.9187,  -6.9680],
        [-24.0967,  -1.5136,  -9.3939],
        [-39.8499,  -3.4143,  -7.4894],
        [-39.4339,  -2.7620,  -4.9838],
        [-38.9434,  -2.2574,  -9.1514],
        [-39.7935,  -3.2119,  -8.9091],
        [-39.5624,  -3.1532,  -8.6984],
        [-39.8407,  -3.2849,  -8.6834],
        [-39.4800,  -2.6796,  -8.3522],
        [-39.7475,  -2.9697,  -8.5852],
        [-14.0868,  -0.7182,  -1.4417],
        [-25.2697,  -1.5439,  -9.1331],
        [-25.2697,  -1.5436,  -9.1400],
        [-14.7557,  -1.0297,  -9.0784],
        [-31.3278,  -2.6671,  -1.9797],
        [-39.5770,  -2.7458,  -8.3887],
        [-33.3246,  -1.9491,  -9.1447],
        [-39.7077,  -3.2034,  -8.8217],
        [-39.6847,  -2.9238,  -8.7152],
        [-39.7517,  -2.9761,  -8.9374],
        [-39.9015,  -3.3208,  -9.0039],
        [-38.9841,  -2.6594,  -4.7241],
        [-22.7583,  -1.6550,  -0.5052],
        [-37.7891,  -3.1992,  -4.0964],
        [-39.7338,  -2.4925,  -7.9306],
        [-40.0264,  -3.3593,  -9.2025],
        [-39.9204,  -3.5783,  -9.0667],
        [-39.7526,  -2.5123,  -8.8017],
        [-39.5562,  -2.6328,  -8.7963],
        [-39.7119,  -3.1534,  -8.7869],
        [-12.4082,  -0.9405,  -9.3145],
        [ -0.8270,  -0.5475,  -9.5308],
        [  0.6396,  -0.5939,  -9.5276],
        [ -5.7053,  -0.6615,  -8.4518],
        [-24.8757,  -1.5035,  -3.9224],
        [-15.3725,  -1.0784,  -9.3229],
        [-13.2660,  -0.9717,  -9.1928],
        [-39.8490,  -2.7305,  -9.0967],
        [-39.9514,  -2.5509,  -9.1402],
        [-39.8491,  -2.8047,  -9.1475],
        [-21.4648,  -1.3831,  -6.9801],
        [-38.1701,  -2.7842,  -4.2645],
        [-31.7073,  -2.6432,  -2.0820],
        [-33.1610,  -2.5683,  -2.4985],
        [-39.4779,  -2.0108,  -8.0653],
        [-38.7626,  -3.0582,  -4.5869],
        [-29.1709,  -2.4652,  -1.4502],
        [-39.7569,  -3.2952,  -5.4407],
        [-34.1180,  -1.9812,  -2.7920],
        [-27.7303,  -1.9334,  -1.1503],
        [-39.6753,  -2.9777,  -8.5423],
        [-23.3060,  -1.7106,  -0.5430],
        [-25.2237,  -1.5520,  -0.7434],
        [-26.6646,  -1.6172,  -0.9585],
        [-39.7935,  -2.2960,  -8.3227],
        [-39.7925,  -2.2662,  -8.2079],
        [-39.6983,  -2.6067,  -8.4784],
        [-39.5948,  -2.3786,  -8.3330],
        [-39.9427,  -2.3425,  -6.0765],
        [-21.1212,  -1.0106,  -0.4484],
        [-28.2063,  -1.4506,  -1.2438],
        [-39.8281,  -2.1742,  -8.3059],
        [-39.5791,  -2.3173,  -8.0701],
        [-39.6606,  -2.5620,  -7.9407],
        [-19.2674,  -1.0374,  -0.4973],
        [-29.0393,  -2.1726,  -1.4209],
        [-39.8818,  -3.1434,  -6.0738],
        [-39.7531,  -2.2184,  -7.0455],
        [-39.8900,  -2.6295,  -8.6100],
        [-39.9424,  -2.8724,  -8.6121],
        [-39.7611,  -2.5735,  -8.4367],
        [-36.6889,  -2.9040,  -3.6714],
        [-39.5708,  -2.6208,  -8.4238],
        [-35.6888,  -2.0146,  -3.3108],
        [-39.8579,  -2.5130,  -7.5782],
        [-39.7405,  -2.3786,  -7.4978],
        [-39.6492,  -2.3815,  -8.3012],
        [-39.5040,  -2.2500,  -8.2464],
        [-39.6944,  -2.8723,  -5.8098],
        [-39.8217,  -2.1441,  -7.8470],
        [-39.7035,  -2.6611,  -8.4499],
        [-39.6450,  -2.4369,  -8.3263],
        [-39.8228,  -2.3329,  -6.1238],
        [-39.5259,  -2.1013,  -8.1558],
        [-14.1077,  -0.7770,  -1.4345],
        [-31.6861,  -1.6563,  -2.0750],
        [-39.7150,  -2.4639,  -8.3911],
        [-35.4009,  -2.0296,  -3.2127],
        [-39.6774,  -2.2340,  -8.0944],
        [-39.5372,  -2.3872,  -6.3153],
        [-39.7600,  -2.0918,  -8.0698],
        [-20.0954,  -1.0011,  -0.4597],
        [-39.7840,  -2.8826,  -6.4466],
        [-35.9395,  -3.1761,  -3.3993],
        [-39.4821,  -2.4118,  -8.2304],
        [-39.4953,  -2.8432,  -7.7095],
        [-36.9911,  -2.0911,  -3.7853],
        [-34.7771,  -2.5738,  -3.0052],
        [-39.5698,  -2.6209,  -8.4327],
        [-39.8396,  -2.3309,  -8.3214],
        [-24.8302,  -1.7171,  -0.6941],
        [-39.4498,  -2.4397,  -8.3116],
        [-39.6230,  -2.2741,  -8.2033],
        [-33.9979,  -3.0686,  -2.7563],
        [-36.1597,  -3.1447,  -3.4772],
        [-28.6463,  -2.3506,  -1.3358],
        [-39.5623,  -2.8811,  -6.5562],
        [-39.7821,  -2.8765,  -6.8973],
        [-40.0579,  -3.7800,  -9.0265],
        [-40.0022,  -3.6863,  -8.9401],
        [-39.6243,  -3.0870,  -5.7083],
        [-25.3471,  -1.2760,  -0.7690],
        [-11.7614,  -0.5677,  -2.3885],
        [-19.3166,  -0.8829,  -0.4946],
        [ -8.5247,  -0.5299,  -0.5922],
        [-10.5118,  -0.8247,  -2.7264],
        [ -4.5752,  -0.2067,  -0.6185],
        [ -4.5752,  -0.6339,  -9.7548],
        [-12.1283,  -0.8918,  -2.9000],
        [-20.0269,  -1.3105,  -8.3560],
        [ -1.5886,  -0.3581,  -0.5984],
        [ -6.5187,  -0.5464,  -0.9241],
        [-11.8847,  -0.2835,  -2.2556],
        [-16.3204,  -1.0306,  -1.8567],
        [ -3.0826,  -0.5886,  -8.0844],
        [ -3.2364,  -0.5816,  -4.3978],
        [ -3.2364,  -0.2841,  -0.4482],
        [ -0.9268,  -0.5467,  -7.9653],
        [-20.7398,  -1.2112,  -1.6775],
        [-13.8529,  -0.9820,  -3.0552],
        [ -6.3109,  -0.4908,  -0.6703],
        [-18.6371,  -1.2425,  -9.0237],
        [ -8.7625,  -0.7666,  -3.4603],
        [-10.7737,  -0.7498,  -1.3023],
        [-15.3725,  -1.0359,  -2.9071],
        [ -0.2125,  -0.0998,  -0.8847],
        [ -4.9399,  -0.6402,  -9.4410],
        [-16.0014,  -0.9748,  -1.4415],
        [-11.8514,  -0.9200,  -6.4215],
        [ -4.3973,  -0.6048,  -3.2071],
        [ -2.4973,  -0.5727,  -7.0623],
        [ -2.4973,  -0.0590,  -1.5581],
        [ -4.0506,  -0.0847,  -1.5109],
        [ -2.0899,  -0.5442,  -3.3131],
        [ -3.7159,  -0.3613,  -0.5081],
        [ -1.9600,  -0.4318,  -0.9446],
        [  0.7610,  -0.5598,  -1.8114],
        [ -2.6391,  -0.4079,  -0.7305],
        [  0.1971,  -0.5111,  -2.0057],
        [ -3.2364,  -0.2644,  -0.4547],
        [-12.4082,  -0.9480,  -9.5088],
        [ -1.8332,  -0.4716,  -1.3684],
        [-10.7737,  -0.7423,  -1.3043],
        [ -9.4942,  -0.7247,  -1.5984],
        [ -3.5531,  -0.4563,  -0.9024],
        [-17.7790,  -0.6341,  -0.6337],
        [ -8.2898,  -0.7608,  -4.3358],
        [-11.8514,  -0.8150,  -1.5011],
        [ -0.2913,  -0.5377,  -3.8324],
        [ -0.4577,  -0.0836,  -4.2343],
        [  0.5564,  -0.6343,  -1.3298],
        [ -0.2913,  -0.3698,  -0.6598],
        [ -0.8270,  -0.5399,  -4.6845],
        [ -2.7840,  -0.5727,  -6.6313],
        [ -0.5457,  -0.4790,  -1.6083],
        [ -7.3802,  -0.7133,  -4.0289],
        [-12.4082,  -0.8242,  -1.5597],
        [-14.1507,  -0.7098,  -0.5804],
        [-22.5747,  -1.3030,  -2.0970],
        [-18.6371,  -1.2420,  -8.1456],
        [ -9.9970,  -0.8233,  -9.0771],
        [-11.5775,  -0.6297,  -0.5897],
        [ -8.5247,  -0.7641,  -9.1713],
        [  0.2548,  -0.5537,  -4.3840],
        [ -4.5814,  -0.5778,  -8.6475],
        [-12.4082,  -0.9345,  -4.9930],
        [ -5.9041,  -0.6693,  -6.4490],
        [ -1.0296,  -0.5281,  -3.2521],
        [ -5.1333,  -0.7497,  -8.0584],
        [ -6.1060,  -0.6823,  -6.6353],
        [ -9.0034,  -0.4890,  -0.4840]], dtype=torch.float64)

with pyro.plate("map", len(data_y)):
    # One Categorical per observation over the hand-written logits above.
    # The original referenced ``probs_of_classes``, which is not defined at
    # module level; the tensor defined in this cell is
    # ``probs_of_classes_for_test``.
    pyro.sample("obs", dist.Categorical(logits=probs_of_classes_for_test), obs=data_y)

In [170]:
# Adding two 0-dim integer tensors.
torch.add(torch.tensor(1), torch.tensor(5))

tensor(6)