In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file mounted under /kaggle/input.
for root, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import math
import copy
import random

from keras.layers import *
from keras.models import Model, load_model
from keras.optimizers import Adam, Nadam, SGD
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.utils import to_categorical
from keras.preprocessing.sequence import pad_sequences

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import normalize
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve, roc_curve
import matplotlib.pyplot as plt
%matplotlib inline 

import seaborn as sns
from sklearn.manifold import TSNE

import tensorflow as tf

# Arrays with constants

In [None]:
# Seed every RNG the notebook uses (random-object picks, augmentation with
# `random` and `np.random`, Keras/TensorFlow training) so Restart & Run All is
# reproducible. GPU kernels may still introduce some nondeterminism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# PLAsTiCC class ids; this order defines the one-hot target / prediction columns.
classes = np.array([6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95, 99], dtype='int32')
class_names = ['class_6','class_15','class_16','class_42','class_52','class_53','class_62','class_64','class_65','class_67','class_88','class_90','class_92','class_95','class_99']
# Human-readable names aligned index-by-index with `classes` (used as plot labels).
real_class_names = ['µLens-Single','TDE','EB','SNII','SNIax','Mira','SNIbc','KN','M-dwarf','SNIa-91bg','AGN','SNIa','RRL','SLSN-I','class_99']
# LSST passbands (nm)  u    g    r    i    z    y      
passbands = np.array([357, 477, 621, 754, 871, 1004], dtype='float32')

num_models = 1

limit = 1000000 #limit of samples

# Load Training Data

In [None]:
# PLAsTiCC training metadata (per object) and light-curve observations.
train_meta = pd.read_csv(os.path.join('../input/PLAsTiCC-2018', 'training_set_metadata.csv'))
train_data = pd.read_csv(os.path.join('../input/PLAsTiCC-2018', 'training_set.csv'))

In [None]:
# Display the training metadata table.
train_meta

**Metadata of nearby SNIa objects (class 90, host spec-z ≤ 0.14) for the ZTF vs LSST light-curve graphic**

In [None]:
# SNIa objects (class 90) whose host galaxy has spec-z <= 0.14.
spec_meta = train_meta.loc[
    (train_meta['target'] == 90) & (train_meta['hostgal_specz'] <= 0.14)
].copy()
spec_meta
# To export: spec_meta.to_csv('./spec_meta_lsst.csv')

In [None]:
# Object ids of the SNIa whose host spec-z equals one of the hand-picked values.
spec_objects = spec_meta.loc[
    spec_meta['hostgal_specz'].isin([0.068, 0.125, 0.0498, 0.12, 0.0763, 0.099, 0.0669,
                                     0.0767, 0.0732, 0.074, 0.11, 0.0845, 0.1181, 0.062]),
    'object_id',
].to_numpy()
spec_objects

In [None]:
# All rows of train_data that belong to the selected SNIa objects.
spec_data = train_data.loc[train_data['object_id'].isin(spec_objects)].copy()
spec_data
# To export: spec_data.to_csv('./spec_data_lsst.csv')

**Plot: spectroscopic (true) vs photometric host-galaxy redshift, coloured by the photo-z error**

In [None]:
# Host-galaxy spec-z vs photo-z for every training object, hue = photo-z error.
fig, ax = plt.subplots(figsize=(15, 8))
sns.scatterplot(data=train_meta, x="hostgal_specz", y="hostgal_photoz",
                hue="hostgal_photoz_err", ax=ax)
fig.savefig('./Photoz_err(specsz).png')

**Keep only detected observations (`detected == 1`), i.e. those where the flux differs from the template**

In [None]:
# Keep only rows flagged as detections.
train_data = train_data.loc[train_data['detected'].eq(1)]

**Select two passbands: g and r (indices 1 and 2 in `passbands`), a ZTF-like subset**

In [None]:
# Passband indices 1 and 2 are g and r (see `passbands`).
train_data_ztf2 = train_data.loc[train_data['passband'].isin([1, 2])]
train_data_ztf2

**Keep objects with more than 5 points (at least 6 observations)**

In [None]:
# Observations per object in the two selected passbands; keep objects with >= 6.
n_points = train_data_ztf2.groupby('object_id')['mjd'].count()
test_objs = n_points.index[n_points >= 6]
train_data_ztf2 = train_data_ztf2.loc[train_data_ztf2['object_id'].isin(test_objs)]
train_data_ztf2

**Mean number of points per object**

In [None]:
# Average number of observations per remaining object.
points_per_object = train_data_ztf2.groupby('object_id').size()
points_per_object.to_numpy().mean()

**Descriptive statistics for 7 classes**, one class at a time (AGN, class 88, below; change the target id for the others)

In [None]:
# Attach per-object metadata (without sky coordinates and distmod) to each observation.
meta = train_meta.drop(columns=['ra', 'decl', 'gal_l', 'gal_b', 'distmod'])
df = train_data_ztf2.merge(meta, how='left', on='object_id')
df['target'].unique()

In [None]:
def describe(df):
    """Return ``df.describe()`` (one row per statistic) plus an extra ``median`` row.

    ``DataFrame.describe`` summarises only the numeric columns by default, so the
    median is restricted to numeric columns too. Without ``numeric_only=True``,
    pandas >= 2.0 raises ``TypeError`` as soon as ``df`` has a non-numeric column.

    Parameters
    ----------
    df : pandas.DataFrame

    Returns
    -------
    pandas.DataFrame
        Rows: count, mean, std, min, 25%, 50%, 75%, max, median.
        Columns: the numeric columns of ``df``.
    """
    stats = df.describe().T
    median = df.median(numeric_only=True).rename('median')
    return pd.concat([stats, median], axis=1).T

In [None]:
# Summary statistics for AGN (class 88), without the id and label columns.
df_stat = describe(df.loc[df['target'] == 88]).drop(columns=['object_id', 'target'])
df_stat.to_excel("class_AGN.xlsx")
df_stat

# Flux Density

In [None]:
# Draw three random class-42 objects (repeated draws collapse in the dict) and
# select all of their observations.
objects = train_meta.loc[train_meta['target'] == 42]
idxs = np.random.randint(len(objects), size=3)
picked = objects.iloc[idxs]
rand_objects = dict(zip(picked['object_id'].astype(int).tolist(),
                        picked['target'].astype(int).tolist()))
print("Random objects:", rand_objects)
df = train_data.loc[train_data['object_id'].isin(list(rand_objects))]

In [None]:
# Kernel density of the raw flux of each randomly picked object.
# `shade=` is deprecated since seaborn 0.11 in favour of `fill=`.
fig, ax = plt.subplots(figsize=(15, 10))
for _obj_id, obs in df.groupby('object_id'):
    sns.kdeplot(obs['flux'], fill=True, legend=False, ax=ax)
ax.set(title='Flux density of the randomly picked objects', xlabel='flux')

# Mean flux per class and example light curves

In [None]:
# Mean flux per object, labelled with the object's class.
df = (
    train_data_ztf2.groupby('object_id')['flux'].mean()
    .rename('m_flux')
    .reset_index()
    .merge(train_meta[['object_id', 'target']], how='left', on='object_id')
)

In [None]:
# Drop objects whose mean flux is abnormally large (>= 120000) before plotting.
df = df.loc[df['m_flux'] < 120000]
sns.catplot(data=df, x="target", y="m_flux", height=7, aspect=2)
plt.savefig('./mean_per_classes')

In [None]:
spec_df = pd.merge(train_data_ztf2, train_meta[['object_id', 'target']], how='left', on='object_id')  # observations + class label

In [None]:
# First object of every class, restricted to eight classes of interest.
object_ids = spec_df.groupby('target').head(1)['object_id'].to_numpy()
spec_df = spec_df.loc[
    spec_df['object_id'].isin(object_ids)
    & spec_df['target'].isin([90, 42, 15, 67, 52, 62, 95, 88])
]
spec_df

In [None]:
# One panel per class: flux vs time, coloured by passband.
g = sns.relplot(data=spec_df, x="mjd", y="flux", hue="passband", col="target", palette='viridis')
g.fig.set_size_inches(91, 9)
g.fig.savefig('./light_curves_lsst')

# Preprocessing

In [None]:
def get_freq(df):
    """Relative frequency of every class present in ``df['target']``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``target`` column of integer class ids.

    Returns
    -------
    dict
        ``{class_id: count / len(df)}`` with keys in ascending class-id order.
        Only classes that occur in ``df`` are returned. The previous version
        indexed exactly ``len(classes) - 1`` entries, so it raised ``IndexError``
        when fewer classes were present and silently dropped extras.
    """
    targets = np.asarray(df['target'], dtype='int32')
    labels, counts = np.unique(targets, return_counts=True)
    return dict(zip(labels, counts / targets.shape[0]))

In [None]:
# Class frequencies of the training metadata; class 99 never appears in training, so it gets 1.
freq_dir = {**get_freq(train_meta), 99: 1.}
freq_dir

In [None]:
def get_keras_data(itemslist, test=False, sequence_len=30):
    """Pack sample dicts (from ``get_data`` / ``augmentate``) into Keras input arrays.

    Parameters
    ----------
    itemslist : list of dict
        Each dict has ``id``, ``meta``, ``band``, ``hist`` and, unless ``test``,
        ``target`` (an index into ``classes``).
    test : bool, default False
        When True no labels are built and only ``X`` is returned.
    sequence_len : int, default 30
        Length every ``band`` / ``hist`` sequence is padded or truncated to.
        Keras ``pad_sequences`` pads and truncates at the front by default, so
        longer light curves keep their last ``sequence_len`` points.

    Returns
    -------
    X : dict of numpy arrays keyed by model input name ('id', 'meta', 'band', 'hist')
    Y : numpy array of one-hot targets with ``len(classes)`` columns (only when
        ``test`` is False)
    """
    print("TESTING" if test else "TRAINING")
    X = {
            'id': np.array([item['id'] for item in itemslist], dtype='int32'),
            'meta': np.array([item['meta'] for item in itemslist]),
            'band': pad_sequences([item['band'] for item in itemslist], maxlen=sequence_len, dtype='int32'),
            'hist': pad_sequences([item['hist'] for item in itemslist], maxlen=sequence_len, dtype='float32'),
        }
    for key, value in X.items():
        print('key = {0}, value_shape = {1}' .format(key, value.shape))

    if test:
        return X

    Y = to_categorical([item['target'] for item in itemslist], num_classes=len(classes))
    print('target, value_shape = {0}' .format(Y.shape))
    return X, Y


In [None]:
def get_data(data_df, meta_df, test = False):
    """Convert raw PLAsTiCC observations into one feature dict per object.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Observations with columns object_id, mjd, passband, flux, flux_err, detected.
    meta_df : pandas.DataFrame
        Per-object metadata with at least object_id, ddf, hostgal_specz,
        hostgal_photoz, hostgal_photoz_err, mwebv and, for labelled data, target.
    test : bool, default False
        When True no ``target`` is attached to the samples.

    Returns
    -------
    list of dict
        At most ``limit`` samples, in ascending object_id order, each with keys:
        ``id``; ``target`` (index into ``classes``, only when not ``test``);
        ``meta`` (float32[5]: ddf, photo-z, photo-z error, mwebv, photo-z > 0);
        ``specz``; ``band`` (passband + 1); ``hist`` (float32[n_obs, 7]: time since
        first observation / 100 / (1 + photo-z), flux and flux_err divided by
        (flux range + 1), detected, time step, source and observed wavelength in µm).
    """
    # Index the metadata once. The previous per-object boolean scan of meta_df
    # made the conversion O(n_objects * len(meta_df)).
    meta_by_id = meta_df.set_index('object_id')
    has_target = 'target' in meta_by_id.columns

    samples = []
    for obj_id, obs in data_df.groupby('object_id'):
        # One metadata row as a Series. Calling int()/float() on a 1-element
        # Series, as before, is deprecated in recent pandas/numpy.
        row = meta_by_id.loc[obj_id]

        sample = {'id': int(obj_id)}
        if not test:
            if has_target:
                sample['target'] = np.where(classes == int(row['target']))[0][0]
            else:
                # No labels available: use the last class index (class 99).
                sample['target'] = len(classes) - 1

        photoz = float(row['hostgal_photoz'])
        # meta = [ddf, photo-z, photo-z error, mwebv, photo-z > 0]
        sample['meta'] = np.array(
            [row['ddf'], photoz, row['hostgal_photoz_err'], row['mwebv'], photoz > 0],
            dtype='float32',
        )
        sample['specz'] = float(row['hostgal_specz'])

        # Redshift as stored (float32-rounded) in the meta vector.
        z = float(sample['meta'][1])

        mjd      = np.array(obs['mjd'],      dtype='float32')
        band     = np.array(obs['passband'], dtype='int32')
        flux     = np.array(obs['flux'],     dtype='float32')
        flux_err = np.array(obs['flux_err'], dtype='float32')
        detected = np.array(obs['detected'], dtype='float32')

        # Time since the first row (assumes obs is time-ordered), in units of
        # 100 days, then divided by (1 + z) to approximate the object's frame.
        mjd -= mjd[0]
        mjd /= 100
        mjd /= (z + 1)

        received_wavelength = passbands[band]              # Earth wavelength in nm
        source_wavelength = received_wavelength / (z + 1)  # Object wavelength in nm

        # Scale flux and its error by the object's flux range; +1 avoids /0 for flat curves.
        flux_scale = (np.max(flux) - np.min(flux)) + 1

        # Shift passbands by one; pad_sequences pads with 0 by default.
        sample['band'] = band + 1

        hist = np.zeros((flux.shape[0], 7), dtype='float32')
        hist[:, 0] = mjd
        hist[:, 1] = flux / flux_scale
        hist[:, 2] = flux_err / flux_scale
        hist[:, 3] = detected
        hist[:, 4] = np.ediff1d(mjd, to_begin=[0])
        hist[:, 5] = source_wavelength / 1000    # µm
        hist[:, 6] = received_wavelength / 1000  # µm
        sample['hist'] = hist

        if not samples:
            print("First sample:")
            print(sample.keys())
            print("id=", sample['id'])
            print("meta=", sample['meta'])
            print("specz=", sample['specz'])
            print("band shape=", sample['band'].shape)
            print("hist shape=", sample['hist'].shape)

        samples.append(sample)

        if len(samples) % 1000 == 0:
            print('Converting data {0}'.format(len(samples)), end='\r')

        if len(samples) >= limit:
            break

    print('Full data number {0}'.format(len(samples)), end='\r')

    return samples

# Augmentation

In [None]:
def copy_sample(s, augmentate=True, drop_rate=0.001, noise_scale=1.5):
    """Return a deep copy of sample ``s``, randomly augmented unless ``augmentate`` is False.

    Augmentation steps, in the order random numbers are drawn:

    1. Each observation is dropped independently with probability ``drop_rate``
       (``random.uniform``). The time-step column ``hist[:, 4]`` is then
       recomputed from the surviving times ``hist[:, 0]``.
    2. The photo-z ``meta[1]`` is redrawn from N(photo-z, photo-z error / noise_scale)
       (``random.normalvariate``) and clipped to [0, 5]. ``get_data`` divided time
       columns 0 and 4 and source-wavelength column 5 by (1 + photo-z), so these
       columns are rescaled by (1 + old_z) / (1 + new_z).
    3. Flux ``hist[:, 1]`` is jittered with N(flux, flux_err / noise_scale)
       (``np.random.normal``).

    Parameters
    ----------
    s : dict
        Sample from ``get_data`` (float32 ``meta`` / ``hist``, int32 ``band``).
    augmentate : bool, default True
        When False only the deep copy is returned.
    drop_rate : float, default 0.001
        Probability of dropping each observation.
    noise_scale : float, default 1.5
        Divisor applied to the photo-z error and to the flux errors.

    Returns
    -------
    dict
        A new sample; ``s`` is not modified. If every observation is dropped, the
        result has ``band`` of shape (0,) and ``hist`` of shape (0, n_features).
        The previous version raised ``IndexError`` in that case.
    """
    c = copy.deepcopy(s)
    if not augmentate:
        return c

    # 1. Drop observations. Boolean masking keeps `hist` 2-D even when nothing
    #    survives (rebuilding it from a list of rows collapsed it to shape (0,)).
    n_obs = s['band'].shape[0]
    keep = np.array([random.uniform(0, 1) >= drop_rate for _ in range(n_obs)], dtype=bool)
    c['hist'] = np.asarray(s['hist'], dtype='float32')[keep]
    c['band'] = np.asarray(s['band'], dtype='int32')[keep]
    times = c['hist'][:, 0]
    c['hist'][:, 4] = np.diff(times, prepend=times[:1])

    # 2. Redraw the photometric redshift and rescale the redshift-dependent columns.
    old_z = c['meta'][1]
    new_z = np.clip(random.normalvariate(old_z, c['meta'][2] / noise_scale), 0, 5)
    dt = (1 + old_z) / (1 + new_z)
    c['meta'][1] = new_z

    # 3. Jitter the flux within its (scaled) measurement error.
    c['hist'][:, 1] = np.random.normal(c['hist'][:, 1], c['hist'][:, 2] / noise_scale)

    c['hist'][:, 0] *= dt
    c['hist'][:, 4] *= dt
    c['hist'][:, 5] *= dt

    return c

In [None]:
def augmentate(samples, count):
    """Return ``count`` augmented copies of every sample, grouped per sample.

    For ``samples = [a, b]`` and ``count = 2`` the result is
    ``[copy(a), copy(a), copy(b), copy(b)]``, each built by ``copy_sample``.
    Progress is printed every 1000 input samples.
    """
    augmented = []
    total = len(samples)
    for position, sample in enumerate(samples, start=1):
        if position % 1000 == 0:
            print('Augmenting {0}/{1}   '.format(position, total), end='\r')
        augmented.extend(copy_sample(sample) for _ in range(count))
    return augmented

# RNN Model

In [None]:
def get_model(X, Y, size=80):
    """Build the GRU light-curve classifier and a feature extractor that shares its layers.

    Parameters
    ----------
    X : dict
        Inputs from ``get_keras_data``; only the per-sample shapes of
        ``X['hist']``, ``X['meta']`` and ``X['band']`` are used.
    Y : numpy.ndarray
        One-hot targets. ``Y.shape[-1]`` sets the number of output classes; it was
        previously hard-coded to 15 and had to match ``classes`` by hand.
    size : int, default 80
        Units per direction of the bidirectional GRU.

    Returns
    -------
    model : keras Model
        Softmax classifier over ``Y.shape[-1]`` classes.
    hidden_model : keras Model
        Same inputs; outputs the global-max-pooled GRU features (before dropout).
    """
    hist_input = Input(shape=X['hist'][0].shape, name='hist')
    meta_input = Input(shape=X['meta'][0].shape, name='meta')
    band_input = Input(shape=X['band'][0].shape, name='band')

    # Bands are stored as passband + 1 (values 1..6, 0 for padding); 8 rows cover them.
    band_emb = Embedding(8, 8)(band_input)

    hist = concatenate([hist_input, band_emb])
    hist = TimeDistributed(Dense(40, activation='relu'))(hist)

    rnn = Bidirectional(GRU(size, return_sequences=True))(hist)
    rnn1 = SpatialDropout1D(0.5)(rnn)

    gmp = GlobalMaxPool1D()(rnn1)
    gmp1 = Dropout(0.5)(gmp)

    x = concatenate([meta_input, gmp1])
    x = Dense(128, activation='relu')(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)

    n_classes = Y.shape[-1]
    output = Dense(n_classes, activation='softmax')(x)

    model = Model(inputs=[hist_input, meta_input, band_input], outputs=output)
    hidden_model = Model(inputs=[hist_input, meta_input, band_input], outputs=gmp)

    return model, hidden_model

**Loss function**

In [None]:
def mywloss(y_true,y_pred):
    """Class-frequency-weighted categorical cross-entropy.

    For each class, the batch mean of ``y_true * log(y_pred)`` is divided by that
    class's frequency from the global ``freq_dir`` (values expected in
    ``classes`` order). The negated mean over classes is returned.
    """
    class_freq = np.array(list(freq_dir.values()))
    clipped = tf.clip_by_value(y_pred, 1e-15, 1 - 1e-15)
    per_class = tf.reduce_mean(y_true * tf.math.log(clipped), axis=0) / class_freq
    return -(tf.reduce_mean(per_class))

In [None]:
def train_model(i, samples_train, samples_valid):
    """Augment the training samples, train model number ``i``, save it and return its History.

    Parameters
    ----------
    i : int
        Model number. Used in the best-checkpoint file name ``./model_{i}.hdf5``;
        the model summary is printed when ``i == 1``.
    samples_train : list of dict
        Training samples. The caller's list is NOT modified: 25 augmented copies
        per sample are appended to a new list.
    samples_valid : list of dict
        Validation samples (not augmented).

    Returns
    -------
    keras History
        The object returned by ``model.fit``. It was previously discarded, which
        left the Loss Curves cell without a ``history`` to plot.
    """
    augment_count = 25
    # Build a new list: `+=` would extend the caller's list in place and keep
    # the 26x larger augmented set alive after training.
    samples_train = samples_train + augmentate(samples_train, augment_count)

    patience = len(samples_train) // 2000 + 10

    print("training data")
    train_x, train_y = get_keras_data(samples_train)
    del samples_train  # the augmented dicts are no longer needed once packed into arrays
    print("testing data")
    valid_x, valid_y = get_keras_data(samples_valid)

    model, rnn_part = get_model(train_x, train_y)

    if i == 1:
        model.summary()
    model.compile(optimizer='nadam', loss=mywloss, metrics=['accuracy'])
    rnn_part.compile(optimizer='nadam', loss=mywloss, metrics=['accuracy'])

    print('Training model {0} of {1}, Patience: {2}'.format(i, num_models, patience))
    filename = './model_{0}.hdf5'.format(i)
    max_epochs = 300

    callbacks = [EarlyStopping(patience=patience, verbose=1),
                 ModelCheckpoint(filename, save_best_only=True)]

    history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y),
                        epochs=max_epochs, batch_size=5000, callbacks=callbacks, verbose=2)

    model.save('./model_last.hdf5')
    rnn_part.save('./rnn_part.hdf5')

    return history

In [None]:
# Feature dicts for the two-passband training subset.
# (For all six passbands use get_data(train_data, train_meta) instead.)
samples = get_data(data_df=train_data_ztf2, meta_df=train_meta)

# Flux Density after normalization

In [None]:
# Density of the normalised flux (hist[:, 1]) for the randomly picked objects that
# survived the two-passband / >= 6 points selection.
# `shade=` is deprecated since seaborn 0.11 in favour of `fill=`.
fig, ax = plt.subplots(figsize=(15, 10))
for obj in samples:
    if obj['id'] in rand_objects:
        sns.kdeplot(obj['hist'][:, 1], fill=True, legend=False, ax=ax)
ax.set(title='Normalised flux density of the randomly picked objects', xlabel='normalised flux')

# Training

In [None]:
valid_size = 0.1

# Use the configured fraction (it was previously defined but ignored).
samples_train, samples_valid = train_test_split(samples, test_size=valid_size, random_state=42)

# Capture train_model's return value so the Loss Curves cell can use it.
history = train_model(num_models, samples_train, samples_valid)

# Loss Curves

In [None]:
# Training vs validation loss per epoch.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='myloss(training data)')
ax.plot(history.history['val_loss'], label='myloss(validation data)')
ax.set_title('Training on data with 2 filters')
ax.set_ylabel('loss value')
ax.set_xlabel('No. epoch')
ax.legend(loc="upper left")
fig.savefig('./Loss_Curves_2filters.png')
plt.show()

# t-SNE

In [None]:
valid_x, valid_y = get_keras_data(samples_valid)
# Feature extractor saved by train_model; the custom loss must be registered to load it.
rnn_part = load_model('./rnn_part.hdf5', custom_objects={'mywloss': mywloss})
rnn_part.summary()

In [None]:
# Validation objects per class index, and the mean count over classes.
neighbors = np.unique(valid_y.argmax(axis=1), return_counts=True)
neighbors[1].mean()

In [None]:
# Validation objects per class, keyed by readable class name.
class_density = dict(zip([real_class_names[num] for num in neighbors[0]], neighbors[1]))
class_density

In [None]:
# Global-max-pooled GRU features of the validation set.
rnn_res = rnn_part.predict(valid_x, batch_size=1000)
print("features shape = ", rnn_res.shape)

# Project the features to 2-D.
# NOTE(review): scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter`; update when upgrading.
tsne = TSNE(n_components=2, init='pca', perplexity=56, n_iter=4000)

print("start fitting t-SNE")
res_tsne = tsne.fit_transform(rnn_res)
print("Final feature shape = ", res_tsne.shape)

In [None]:
# # color palette visualization
# num2class_name = dict(zip(classes[:-1], real_class_names))
# print(num2class_name)
# sns.palplot(sns.color_palette("Spectral", 14))

In [None]:
# Class id -> readable name. This mapping previously existed only in a
# commented-out cell, so `num2class_name` was undefined here.
num2class_name = dict(zip(classes[:-1], real_class_names))
# 14-colour palette, also used by the six-class panels below.
color_values = sns.color_palette("Spectral", 14)

class_list = [num2class_name[classes[num]] for num in np.argmax(valid_y, axis = 1)]

fig, ax = plt.subplots(figsize=(15,10))
ax.set_title('t-SNE embedding: layer - Bidirectional GRU')
# One colour per class actually present. Depending on the seaborn version, a
# 14-colour list errors or warns when the validation split lacks a class.
n_present = len(set(class_list))
sns.scatterplot(x=res_tsne[:,0], y=res_tsne[:,1], hue=class_list,
                palette=sns.color_palette("Spectral", n_present), legend="full", ax=ax)
fig.savefig('./t-SNE_deep_layer(2).png')

**t-SNE embedding with each of six classes highlighted against all others**

In [None]:
class_list = np.array(class_list)
names = ['TDE', 'SNIbc', 'M-dwarf', 'EB', 'SNII', 'SNIa']
fig, axes = plt.subplots(2, 3, figsize=(20, 10))
for name, ax in zip(names, axes.flatten()):
    hue = np.where(class_list == name, class_list, 'others')
    # Fix the level order. Otherwise seaborn orders string levels by first
    # appearance, so the highlighted class could get a different colour per panel.
    sns.scatterplot(x=res_tsne[:,0], y=res_tsne[:,1], hue=hue,
                    hue_order=[name, 'others'], palette=color_values[5::8], ax=ax)
    ax.set_title(name)
fig.savefig('./t-SNE_deep_layer_6classes(2).png')

# Results on validation data

In [None]:
def multi_weighted_logloss(y_valid, y_pred, freq):
    """
    @author olivier https://www.kaggle.com/ogrellier
    multi logloss for PLAsTiCC challenge

    Class-weighted multi-class log loss with the PLAsTiCC weights: 2 for classes
    15 and 64, 1 for every other class (class 99 included).

    Parameters
    ----------
    y_valid : numpy.ndarray, shape (n_objects, 15)
        One-hot true labels; columns in ascending class-id order (6, 15, ..., 95, 99).
    y_pred : numpy.ndarray, shape (n_objects, 15)
        Predicted probabilities in the same column order; clipped to [1e-15, 1 - 1e-15]
        (no row renormalisation is done).
    freq : numpy.ndarray, shape (15,)
        Per-class divisor standing in for the number of positives, as a fraction;
        the notebook passes the training class frequencies (1 for class 99).

    Returns
    -------
    float
        -sum_c(w_c * sum_i y_ic * log p_ic / freq_c) / sum_c(w_c) / n_objects
    """
    weights = {6: 1, 15: 2, 16: 1, 42: 1, 52: 1, 53: 1, 62: 1, 64: 2, 65: 1, 67: 1, 88: 1, 90: 1, 92: 1, 95: 1, 99: 1}
    weight_arr = np.array([weights[cls] for cls in sorted(weights)])

    log_p = np.log(np.clip(y_pred, 1e-15, 1 - 1e-15))
    # Per class: sum of log-probabilities assigned to that class's true members.
    per_class = (y_valid * log_p).sum(axis=0)
    weighted = per_class * weight_arr / freq
    return -weighted.sum() / weight_arr.sum() / y_valid.shape[0]

In [None]:
print("testing data")
valid_x, valid_y = get_keras_data(samples_valid)

# Last-epoch weights; the best checkpoint from ModelCheckpoint is './model_1.hdf5'.
filename = './model_last.hdf5'
model = load_model(filename, custom_objects={'mywloss': mywloss})

freq = np.array(list(freq_dir.values()))  # class frequencies, in `classes` order
preds = model.predict(valid_x, batch_size=1000)
loss = multi_weighted_logloss(valid_y, preds, freq)
acc = accuracy_score(valid_y.argmax(axis=1), preds.argmax(axis=1))
print('MW Loss: {0:.4f}, Accuracy: {1:.4f}'.format(loss, acc))


In [None]:
# class_names: 'µLens-Single','TDE','EB',' SNII',' SNIax','Mira','SNIbc',' KN','M-dwarf','SNIa-91bg','AGN ','SNIa',' RRL',' SLSN-I','class_99'

**Confusion matrix**

In [None]:
cm = confusion_matrix(valid_y.argmax(axis=1), preds.argmax(axis=1), labels=np.arange(len(classes) - 1))  # raw counts, class 99 excluded

# P vs R barplot

In [None]:
# Per-class recall (row-wise) and precision (column-wise) from the raw confusion matrix.
recall = np.diag(cm) / cm.sum(axis=1)
precision = np.diag(cm) / cm.sum(axis=0)

fig, ax = plt.subplots(figsize=(15,5))
labels = real_class_names[:-1]
x = np.arange(len(classes) - 1)  # bar group positions
width = 0.35                     # width of each bar
rects1 = ax.bar(x - width / 2, precision, width, label='Precision')
rects2 = ax.bar(x + width / 2, recall, width, label='Recall')

ax.set(ylabel='Scores', title='Precision/Recall per class', xticks=x)
ax.set_xticklabels(labels)
ax.legend()
fig.savefig('./PvsR_barplot.png')

In [None]:
# Row-normalise into a NEW name so `cm` keeps raw counts. Overwriting `cm` made a
# re-run of the precision/recall cell compute precision from row fractions.
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
annot = np.around(cm_norm, 2)

fig, ax = plt.subplots(figsize=(10,9))
sns.heatmap(cm_norm, xticklabels=real_class_names[:-1], yticklabels=real_class_names[:-1],
            cmap='Blues', annot=annot, ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_aspect('equal')
fig.tight_layout()
fig.savefig('./conf_matrix.png')

# P vs R curve

In [None]:
# One-vs-rest precision-recall curve for each of the 14 training classes.
precision, recall = {}, {}
fig, ax = plt.subplots(figsize=(15,10))
for i, name in enumerate(real_class_names[:len(classes) - 1]):
    precision[i], recall[i], _ = precision_recall_curve(valid_y[:, i], preds[:, i])
    ax.plot(recall[i], precision[i], lw=2, label='class {}'.format(name))

ax.set_xlabel("recall")
ax.set_ylabel("precision")
ax.legend(loc="best")
ax.set_title("precision vs recall curve")
fig.savefig('./PvsR_curves.png')

# ROC curve

In [None]:
# One-vs-rest ROC curve for each of the 14 training classes.
fpr, tpr = {}, {}
fig, ax = plt.subplots(figsize=(15,10))
for i, name in enumerate(real_class_names[:len(classes) - 1]):
    fpr[i], tpr[i], _ = roc_curve(valid_y[:, i], preds[:, i])
    ax.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(name))

ax.set_xlabel("false positive rate")
ax.set_ylabel("true positive rate")
ax.legend(loc="best")
ax.set_title("ROC curve")
fig.savefig('./ROC_curves(2).png')

In [None]:
tf.keras.utils.plot_model(model, "./model(2).png", show_shapes=True)  # architecture diagram (plot_model needs pydot + graphviz)

# Test Data

In [None]:
# First test batch, the submission template and the test metadata.
test_data = pd.read_csv(os.path.join('../input/PLAsTiCC-2018', 'test_set_batch1.csv'))
test_target = pd.read_csv(os.path.join('../input/PLAsTiCC-2018', 'sample_submission.csv'))
test_meta = pd.read_csv(os.path.join('../input/PLAsTiCC-2018', 'test_set_metadata.csv'))

In [None]:
# Peek at the first rows of the test metadata.
test_meta.head()

In [None]:
# One past the largest object id in the first test batch.
max_id = 1 + test_data['object_id'].max()
print(max_id)

In [None]:
# Keep exactly the objects present in test_set_batch1, sorted by object_id so the
# rows follow the order of get_data(), which iterates groupby('object_id') (sorted ids).
# The previous `object_id < max_id` filter could also keep ids below max_id that
# are not in this batch, misaligning labels and predictions.
# NOTE(review): confirm sample_submission.csv holds real labels for this dataset;
# in the original Kaggle competition it was only a submission template.
test_target = (
    test_target.loc[test_target['object_id'].isin(test_data['object_id'].unique())]
    .sort_values('object_id')
)
test_target.head()

In [None]:
samples = get_data(data_df=test_data, meta_df=test_meta, test=True)  # NOTE(review): test rows are not restricted to detected points / passbands 1-2 / >= 6 points like the training data; confirm intended

In [None]:
print("testing data")
valid_x = get_keras_data(itemslist=samples, test=True)  # inputs only, no labels

In [None]:
valid_y = test_target.set_index('object_id').loc[[s['id'] for s in samples]].to_numpy()  # one label row per sample, in sample order (KeyError if an id is missing)

In [None]:
# Best checkpoint saved by ModelCheckpoint during training (model number 1).
filename = './model_1.hdf5'
model = load_model(filename, custom_objects={'mywloss': mywloss})

freq = np.array(list(freq_dir.values()))  # frequencies of classes
preds = model.predict(valid_x, batch_size=5000)
loss = multi_weighted_logloss(valid_y, preds, freq)
acc = accuracy_score(valid_y.argmax(axis=1), preds.argmax(axis=1))
print('MW Loss: {0:.4f}, Accuracy: {1:.4f}'.format(loss, acc))

In [None]:
# Row-normalised confusion matrix over all 15 classes (class 99 included).
cm = confusion_matrix(valid_y.argmax(axis=1), preds.argmax(axis=1), labels=np.arange(len(classes))).astype('float')
cm /= cm.sum(axis=1)[:, np.newaxis]
annot = np.around(cm, 2)

fig, ax = plt.subplots(figsize=(10,9))
sns.heatmap(cm, xticklabels=real_class_names, yticklabels=real_class_names, cmap='Blues', annot=annot, ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_aspect('equal')
fig.tight_layout()