## Covid Project

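In this data science project we want to use data from the Coswara database (uploaded to Kaggle: https://www.kaggle.com/praveengovi/coronahack-respiratory-sound-dataset) to build a model that distinguishes COVID-positive from healthy cases based on their respiratory sound recordings.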


### Data Structure

There are 1397 cases, of which 56 are positive. Each case is composed of 9 independent recordings:
`['counting-normal','counting-fast','breathing-deep','breathing-shallow','cough-heavy','cough-shallow','vowel-a','vowel-e','vowel-o']`

### Potential Solution

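Use an auto-encoder for out-of-distribution detection, trained on "healthy" cases only: recordings from positive cases differ from the training distribution and are therefore expected to be reconstructed less accurately.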
Proposed solution (https://github.com/moiseshorta/MelSpecVAE).
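The out-of-distribution idea can be sketched as follows: after training on healthy cases only, each new case is scored by its reconstruction error, and cases above a threshold derived from the healthy errors are flagged. This is a minimal NumPy sketch under assumptions: the function names, the per-sample mean squared error and the 95th-percentile cut-off are illustrative choices, not part of the project code, and the toy arrays are random stand-ins rather than real results.

```python
import numpy as np


def reconstruction_errors(inputs, reconstructions):
    """Per-sample mean squared error, averaged over all non-batch axes."""
    diff = np.asarray(inputs, dtype=np.float64) - np.asarray(reconstructions, dtype=np.float64)
    return (diff ** 2).reshape(len(diff), -1).mean(axis=1)


def fit_threshold(healthy_errors, percentile=95.0):
    """Cut-off below which about `percentile` % of the healthy training errors fall."""
    return float(np.percentile(healthy_errors, percentile))


def flag_out_of_distribution(errors, threshold):
    """1 = error above the threshold (candidate positive case), 0 = looks like the healthy data."""
    return (np.asarray(errors) > threshold).astype(int)


# Toy data standing in for spectrograms and their reconstructions (not real results)
rng = np.random.default_rng(0)
healthy = rng.random((100, 28, 28, 1))
healthy_recon = healthy + rng.normal(0.0, 0.01, healthy.shape)  # small reconstruction error
unseen = rng.random((5, 28, 28, 1))
unseen_recon = unseen + rng.normal(0.0, 0.2, unseen.shape)      # much larger error

threshold = fit_threshold(reconstruction_errors(healthy, healthy_recon))
flags = flag_out_of_distribution(reconstruction_errors(unseen, unseen_recon), threshold)
print(threshold, flags)
```

In practice the healthy errors used for the threshold would come from held-out healthy cases, so that the cut-off is not fitted on the same data the auto-encoder was trained on.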

## #Chunk 1
### Libraries

In [1]:

#Data handling and visualization

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

#Audio analysis
import glob
import IPython
import librosa
import librosa.display
import tensorflow as tf
import tensorflow_io as tfio #audio loading, spectrogram and mel-scale ops (used in Chunk 4)
from tensorflow import keras

#File paths
import os

## #Chunk 2
### Import metadata (file path information)

In [2]:
# import the metadata
# The metadata CSV contains additional information about each case (country, age, COVID status, symptoms, ...).
# Nine columns, one per recording type, contain the paths (relative to the dataset folder) to the .wav files of each case.
df_meta = pd.read_csv('./CoronaHack-Respiratory-Sound-Dataset/Corona-Hack-Respiratory-Sound-Metadata.csv')
df_meta.info(), df_meta.shape


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1397 entries, 0 to 1396
Data columns (total 37 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   USER_ID                 1397 non-null   object 
 1   COUNTRY                 1397 non-null   object 
 2   AGE                     1397 non-null   int64  
 3   COVID_STATUS            1396 non-null   object 
 4   ENGLISH_PROFICIENCY     1397 non-null   object 
 5   GENDER                  1397 non-null   object 
 6   COUNTY_RO_STATE         1397 non-null   object 
 7   CITY_LOCALITY           1228 non-null   object 
 8   Diabetes                1397 non-null   int64  
 9   Asthma                  1397 non-null   int64  
 10  Smoker                  1397 non-null   int64  
 11  Hypertension            1397 non-null   int64  
 12  Fever                   1397 non-null   int64  
 13  Returning_User          1397 non-null   int64  
 14  Using_Mask              1397 non-null   

(None, (1397, 37))

In [3]:
df_meta.head()

Unnamed: 0,USER_ID,COUNTRY,AGE,COVID_STATUS,ENGLISH_PROFICIENCY,GENDER,COUNTY_RO_STATE,CITY_LOCALITY,Diabetes,Asthma,...,DATES,breathing-deep,breathing-shallow,cough-heavy,cough-shallow,counting-fast,counting-normal,vowel-a,vowel-e,vowel-o
0,vK2bLRNzllXNeyOMudnNSL5cfpG2,India,24,healthy,Y,M,Karnataka,Bangalore,0,0,...,20200413,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...,/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cf...
1,bjA2KpSxneNskrLBeqi4bqoTDQl2,India,72,healthy,Y,M,Maharashtra,Thane,0,0,...,20200413,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...,/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTD...
2,FSzobvJqOXf0rI6X05cHqOiU9Mu2,India,54,healthy,Y,M,Maharashtra,Thane West,0,0,...,20200413,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...,/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9...
3,EqDWckxbsETyHUeBLQ8jLtxlhir2,India,31,healthy,Y,M,Karnataka,Bangalore,0,0,...,20200413,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...,/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlh...
4,FGRDO4IBbAejR0WHD5YbkXTCasg2,India,26,healthy,Y,M,Haryana,gurgaon,0,0,...,20200413,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...,/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCa...
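The path columns shown above start with "/data/...", so they have to be prefixed with the dataset folder before the audio can be loaded. A minimal plain-Python sketch of that step, assuming the same `./CoronaHack-Respiratory-Sound-Dataset` root used later in this notebook; `build_wav_paths` is a hypothetical helper, not part of the original code. It also returns the row positions it kept, so labels can be filtered the same way.

```python
import math

BASE_PATH = "./CoronaHack-Respiratory-Sound-Dataset"  # dataset root used throughout this notebook


def build_wav_paths(relative_paths, base_path=BASE_PATH):
    """Prefix metadata paths with the dataset root and skip missing entries.

    Returns the full paths and the row positions they came from, so the
    matching labels can be selected with the same positions.
    """
    paths, kept_rows = [], []
    for row, rel in enumerate(relative_paths):
        # pandas stores missing strings as float NaN, so `is not None` alone is not enough
        if rel is None or (isinstance(rel, float) and math.isnan(rel)):
            continue
        # plain concatenation on purpose: os.path.join(base, "/data/...") would drop
        # `base`, because the metadata paths start with "/"
        paths.append(base_path.rstrip("/") + "/" + str(rel).lstrip("/"))
        kept_rows.append(row)
    return paths, kept_rows


example = [
    "/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cfpG2/counting-normal.wav",
    float("nan"),
    None,
]
paths, kept_rows = build_wav_paths(example)
print(paths, kept_rows)
```

With the real metadata this would be called as `build_wav_paths(df_meta2['counting-normal'].tolist())`, once per recording column.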


## #Chunk 3
### Get the label for each case

In [3]:
#Get the label (healthy / COVID)

#Split the COVID_STATUS column at "_" and keep the first part as the label in column 'split'
df_meta['split'] = df_meta['COVID_STATUS'].str.split('_').str.get(0)
#Check for NA (only the last expression of a cell is displayed)
df_meta.loc[:,'counting-normal'].isna().sum()
df_meta.loc[:,'split'].value_counts()

#Dictionary to re-categorize the split column: every non-positive status counts as 0
cat_dict = {'healthy':0,'no':0,'resp':0,'recovered':0,'positive':1}

#Map cat_dict onto the split column; missing or unknown statuses become NaN
df_meta.loc[:,'split'] =  df_meta.loc[:,'split'].map(cat_dict)
#Drop rows without a label; .copy() makes df_meta2 an independent frame (avoids the SettingWithCopyWarning)
df_meta2 = df_meta.dropna(subset=['split']).copy()
df_meta2['split'] = df_meta2['split'].astype('int32')


#Extract positive and negative USER_IDs
df_meta_positives = df_meta2[df_meta2['split'] == 1]
df_meta_negatives = df_meta2[df_meta2['split'] == 0]

positives = list(df_meta_positives['USER_ID'])
negatives = list(df_meta_negatives['USER_ID'])
len(positives),len(negatives)


(56, 1340)
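The mapping in this cell boils down to one rule: the part of `COVID_STATUS` before the first "_" decides the class, and only "positive" maps to 1. A plain-Python sketch of that rule, for checking individual values; `status_to_label` is an illustrative name, and the example status strings are assumed values of the "prefix_detail" form the split expects, not read from the CSV.

```python
# Same rule as the cell above: the part of COVID_STATUS before the first "_" decides the label
CAT_DICT = {"healthy": 0, "no": 0, "resp": 0, "recovered": 0, "positive": 1}


def status_to_label(status):
    """Return 0 (negative) or 1 (positive), or None for a missing or unknown status."""
    if not isinstance(status, str):
        return None  # e.g. the single missing COVID_STATUS value, which pandas reads as NaN
    return CAT_DICT.get(status.split("_")[0])


# Assumed example values of the "<prefix>_<detail>" form
statuses = ["healthy", "positive_mild", "no_resp_illness_exposed", "recovered_full", float("nan")]
labels = [status_to_label(s) for s in statuses]
print(labels)
```

Rows whose label comes out as None correspond to the NaN rows that `dropna(subset=['split'])` removes.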

## #Chunk 5
### Generate a function to create the input data for the auto-encoder

In [29]:
# Function to load and prepare the input data
# here we want to use the 9 recordings as separate features, grouped per case, as input to the auto-encoder

#names of the 9 recordings per case (column names in the metadata CSV)
names_input = ['counting-normal','counting-fast','breathing-deep','breathing-shallow','cough-heavy','cough-shallow','vowel-a','vowel-e','vowel-o']
#label column from the meta data csv (#Chunk 3)
name_label = 'split'

def create_input_label(df=df_meta2,names=names_input,name_label=name_label):
    input_dic = {} #dictionary holding one dataset per recording type (9 per case)
    base_path = './CoronaHack-Respiratory-Sound-Dataset'
    
    for index,name in enumerate(names):
        #print(index,name)
        print("Create input run")
        path_list = df[name].tolist()
        print(path_list[:10])
        path_name = []
        for dir_name in path_list:
            path_name.append(base_path+str(dir_name))

        print(path_name[:10])
        print("Sound paths convert to tensor")
        sound_paths_tensor = tf.convert_to_tensor(path_name, dtype=tf.string) #convert to tensor

        print("Sound PATH", sound_paths_tensor[0])
        print("Sound Dataset from tensor slices")
        sound = tf.data.Dataset.from_tensor_slices(sound_paths_tensor)
        print("Sound PATH from slices", sound[0])
        #sound = tf.data.Dataset.from_generator(lambda sample: preprocess_other(sample).batch(32), output_types=tf.int32, output_shapes = (64,64,1),)
        print("Calling preprocessing")
        print("SOUNDD", sound)
        input_dic['x_{}'.format(index)] = sound.map(lambda sample: preprocess_other(sample)) #one lazily mapped dataset per recording type, stored under the keys x_0 ... x_8


    path_label = df[name_label]
    #print(path_label)
    y = tf.convert_to_tensor(path_label, dtype=tf.int16)

    return input_dic,y

    

In [30]:
x,y = create_input_label()
x = list(x.values())
x

Create input run
['/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cfpG2/counting-normal.wav', '/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTDQl2/counting-normal.wav', '/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9Mu2/counting-normal.wav', '/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlhir2/counting-normal.wav', '/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCasg2/counting-normal.wav', '/data/train/20200413/htQzROl26OWQpIYFDzv11F79PLR2/counting-normal.wav', '/data/train/20200413/pW9mCAeWYiMoM7wW7riLvNRbYDO2/counting-normal.wav', '/data/train/20200413/Eu11s84cuBTiPXTAtVf9mj3GkqA2/counting-normal.wav', '/data/train/20200413/L7S8iIPKgiO6QWLC3mGkROCMa0s1/counting-normal.wav', '/data/train/20200413/eP8gEM0KcBU6S5JpMdycX74KP3p2/counting-normal.wav']
['./CoronaHack-Respiratory-Sound-Dataset/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cfpG2/counting-normal.wav', './CoronaHack-Respiratory-Sound-Dataset/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTDQl2/counting-normal.wav', './CoronaHack-Respiratory-Sou

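TypeError: 'TensorSliceDataset' object does not support indexing

The error comes from `sound[0]`: a `tf.data.Dataset` cannot be indexed like a list, whereas `next(iter(sound))` returns its first element. Even with that fixed, `sound` holds file-path strings rather than decoded audio, so `preprocess_other` would fail when casting a string to `tf.float32`. This is why the experimental `create_input_label2` in Chunk 4 decodes the .wav files with `tfio.audio.AudioIOTensor` before mapping `preprocess_other`.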

## #Chunk 4
### Define Function for .wav import and preprocessing 

In [4]:
# Preprocessing of one decoded recording: waveform -> normalized mel spectrogram (code adapted from Tristan's class)

def preprocess_other(sample):
  print("Start preprocessing, setting up the shape of sample")
  print("Sample", sample)

  #label = sample['label']
  audio = tf.reshape(sample, [-1]) #flatten (samples, channels) into a 1-D waveform

  print("PY-PREPROCESS set audio file as float", type(audio))
  audio = tf.cast(audio, tf.float32) #set audio file as float
  #audio = audio[24500:5000+len(audio)//10]
  # Plot audio amplitude
  # plt.figure(figsize=(10,15))
  # plt.plot(audio)
  # plt.show()
  # plt.close()

  print(audio)
  print("PY-PREPROCESS generate the mel spectrogram")
  #generate the spectrogram, shape (time, frequency)
  spectrogram = tfio.audio.spectrogram(
      audio, nfft=1024, window=1024, stride=64
  )

  #NOTE: rate should be the sample rate of the recordings; RATE ME in Chunk 6 reports 48000 Hz,
  #so with rate=8000 the mel filters are probably placed on the wrong frequencies
  spectrogram = tfio.audio.melscale(
      spectrogram, rate=8000, mels=64, fmin=0, fmax=2000 #mels = number of mel bins, fmin/fmax = frequency range in Hz
  )

  print("PY-PREPROCESS divide by the spectrogram maximum")
  spectrogram /= tf.math.reduce_max(spectrogram) #normalize to [0, 1]
  spectrogram = tf.expand_dims(spectrogram, axis=-1) #add a channel dimension, 2D -> 3D
  spectrogram = tf.image.resize(spectrogram, (image_target_height, image_target_width)) #resize in two dimensions
  spectrogram = tf.transpose(spectrogram, perm=(1, 0, 2)) #swap the first two axes: (time, mel) -> (mel, time)
  spectrogram = spectrogram[::-1, :, :] #flip the mel axis so the highest frequencies come first

  # plt.figure(figsize=(10,15))
  # plt.imshow(spectrogram[::-1,:], cmap='inferno') #flipping upside down
  # plt.show()
  # plt.close()

  # Reshape to fit the VAE model here, while it is still a tensor:
  # reshaping the final output (the dataset) is not possible
  #spectrogram = tf.reshape(spectrogram, [-1 ,28, 28, 1])

  print("SPECTROGRAM: ", spectrogram)

  return spectrogram
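The `rate` argument of the mel-scale step matters because a mel filter bank needs to know which frequency each FFT bin represents, usually k·rate/nfft for bin k. Assuming `tfio.audio.melscale` follows this usual convention, the plain-Python arithmetic below compares how many FFT bins fall between fmin=0 and fmax=2000 Hz when the filter bank is told 8000 Hz versus the 48000 Hz reported by `RATE ME` in Chunk 6. The helper names are illustrative only.

```python
def fft_bin_frequencies(rate, nfft):
    """Frequency in Hz of each one-sided FFT bin k = 0 .. nfft/2, i.e. k * rate / nfft."""
    return [k * rate / nfft for k in range(nfft // 2 + 1)]


def bins_up_to(fmax, rate, nfft):
    """Number of FFT bins at or below fmax for a given sample rate."""
    return sum(1 for f in fft_bin_frequencies(rate, nfft) if f <= fmax)


nfft, fmax = 1024, 2000
bins_assumed = bins_up_to(fmax, rate=8000, nfft=nfft)   # rate passed to melscale above
bins_actual = bins_up_to(fmax, rate=48000, nfft=nfft)   # rate reported for the recordings
# Real frequency, in a 48 kHz recording, of the highest bin the filter bank treats as "2000 Hz"
real_top_hz = (bins_assumed - 1) * 48000 / nfft
print(bins_assumed, bins_actual, real_top_hz)
```

Under this assumption the 64 mel bins labelled 0 to 2000 Hz would actually span roughly 0 to 12 kHz of the recordings, while with the true rate only a few dozen FFT bins lie below 2000 Hz, so changing `rate` probably also calls for revisiting `fmax` and `mels`.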

In [5]:
# Experimental version of create_input_label (Chunk 5): decodes the .wav files with tfio before preprocessing

import matplotlib.pyplot as plt
import tensorflow_io as tfio
# Create function to load and prepare data for input 
# here we want to use the 9 recordings as separate features but grouped per case as input to the auto-encoder 

#names of the 9 recordings per case (column names in the metadata CSV)
#names_input = ['counting-normal','counting-fast','breathing-deep','breathing-shallow','cough-heavy','cough-shallow','vowel-a','vowel-e','vowel-o']
names_input = ['counting-normal']
#label column from the meta data csv (#Chunk 3)
name_label = 'split'
image_target_height, image_target_width = 28, 28

IS_VAE = True

def create_input_label2(df=df_meta2,names=names_input,name_label=name_label):
    input_dic = {} #dictionary holding one dataset per recording type (9 per case)
    base_path = './CoronaHack-Respiratory-Sound-Dataset'
    for index,name in enumerate(names):
        print(index,name)
        print("create path list")
        path_list = df[name].tolist()
        print(path_list[:10])

        path_name = []
        print("create path name")
        for dir_name in path_list:
            if isinstance(dir_name, str): #missing paths are float NaN in pandas, not None
                path_name.append(base_path+str(dir_name))

        #path_name = base_path+str(path_list[0])
        print("create sound tensor")
        
            
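        #decode each .wav file and keep at most its first 300000 samples
        sound_tensor_list = [tfio.audio.AudioIOTensor(sound_path).to_tensor()[:300000] for sound_path in path_name]
        sound_rate_tensor_list = tfio.audio.AudioIOTensor(path_name[0]).rate #sample rate of the first file
        print("DIRTY", len(sound_tensor_list))
        #keep only clips with the full 300000 samples; shorter recordings are dropped
        sound_tensor_list_clean = [sound_tensor for sound_tensor in sound_tensor_list if sound_tensor.shape[0] == 300000]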
        print("CLEAN", len(sound_tensor_list_clean))


        print("SHAPE ME", sound_tensor_list[0][:100000].shape)
        print("RATE ME", sound_rate_tensor_list)
        print("create Sound Slices")
        sound_slices = tf.data.Dataset.from_tensor_slices(sound_tensor_list_clean)


        print("create input dictionary")
        input_dic['x_{}'.format(index)] = sound_slices.map(lambda sample: preprocess_other(sample)) #one lazily mapped dataset per recording type (keys x_0 ... x_8)
        break #only the first recording type is processed for now
       
    
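    #NOTE: y has one label per row of df (1396 here), but the dataset above only keeps the
    #full-length clips (1328, see the CLEAN count in Chunk 6), so clips and labels are not aligned
    path_label = df[name_label]
    print(path_label)
    y = tf.convert_to_tensor(path_label, dtype=tf.int16)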

    return input_dic, y
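In this function the length filter shrinks the clip list from 1396 to 1328 (the DIRTY and CLEAN counts in Chunk 6), while `y` is still built from every row of `df_meta2`. A minimal sketch of filtering clips and labels together, using NumPy arrays of shape (samples, channels) as stand-ins for the decoded audio; `crop_and_align` is a hypothetical helper, not part of the original code.

```python
import numpy as np


def crop_and_align(clips, labels, n_samples):
    """Keep clips with at least n_samples samples (cropped to n_samples) and their labels, in order."""
    kept_clips, kept_labels = [], []
    for clip, label in zip(clips, labels):
        clip = np.asarray(clip)
        if clip.shape[0] >= n_samples:
            kept_clips.append(clip[:n_samples])
            kept_labels.append(label)
    if not kept_clips:
        raise ValueError("no clip is long enough")
    return np.stack(kept_clips), np.asarray(kept_labels)


# Toy stand-ins for decoded recordings: the middle clip is too short and is dropped with its label
clips = [np.zeros((10, 1)), np.zeros((4, 1)), np.ones((12, 1))]
labels = [0, 1, 0]
X, y = crop_and_align(clips, labels, n_samples=8)
print(X.shape, y)
```

The aligned arrays could then be turned into one dataset with `tf.data.Dataset.from_tensor_slices((X, y))`, which keeps every clip paired with its label through later `map` and `batch` calls.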

    

## #Chunk 6
### Test the output of the function

In [6]:
x_, y = create_input_label2()
x_ = list(x_.values())
x_[0].batch(256)

0 counting-normal
create path list
['/data/train/20200413/vK2bLRNzllXNeyOMudnNSL5cfpG2/counting-normal.wav', '/data/train/20200413/bjA2KpSxneNskrLBeqi4bqoTDQl2/counting-normal.wav', '/data/train/20200413/FSzobvJqOXf0rI6X05cHqOiU9Mu2/counting-normal.wav', '/data/train/20200413/EqDWckxbsETyHUeBLQ8jLtxlhir2/counting-normal.wav', '/data/train/20200413/FGRDO4IBbAejR0WHD5YbkXTCasg2/counting-normal.wav', '/data/train/20200413/htQzROl26OWQpIYFDzv11F79PLR2/counting-normal.wav', '/data/train/20200413/pW9mCAeWYiMoM7wW7riLvNRbYDO2/counting-normal.wav', '/data/train/20200413/Eu11s84cuBTiPXTAtVf9mj3GkqA2/counting-normal.wav', '/data/train/20200413/L7S8iIPKgiO6QWLC3mGkROCMa0s1/counting-normal.wav', '/data/train/20200413/eP8gEM0KcBU6S5JpMdycX74KP3p2/counting-normal.wav']
create path name
create sound tensor
DIRTY 1396
CLEAN 1328
SHAPE ME (100000, 1)
RATE ME tf.Tensor(48000, shape=(), dtype=int32)
create Sound Slices
create input dictionary
Start preprocessing, setting up the shape of sample
Sample Ten

<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

## #Chunk 7
### Build the auto-encoder architecture (code adapted from Tristan's class)

In [8]:
from tensorflow.keras import models, layers
class AutoEncoder(tf.keras.Model):
    
    def __init__(self, latent_dim):
        super().__init__()

        self.latent_dim = latent_dim

        # Encoder
        self.encoder_reshape = layers.Reshape((image_target_height * image_target_width,)) #flatten the 28x28x1 spectrogram into a 784-vector
        self.encoder_fc1 = layers.Dense(32, activation="relu")
        self.encoder_fc2 = layers.Dense(latent_dim, activation="relu")

        # Decoder
        self.decoder_fc1 = layers.Dense(32, activation='relu')
        self.decoder_fc2 = layers.Dense(image_target_height * image_target_width, activation='sigmoid')
        self.decoder_reshape = layers.Reshape((image_target_height, image_target_width,1))

        self._build_graph()

    def _build_graph(self):
        input_shape = (image_target_height, image_target_width, 1)
        self.build((None,)+ input_shape)
        inputs = tf.keras.Input(shape=input_shape)
        _= self.call(inputs)

    def call(self, x):
        z = self.encode(x)
        x_new = self.decode(z)
        return x_new

    def encode(self, x):
        x = self.encoder_reshape(x)
        x = self.encoder_fc1(x)
        z = self.encoder_fc2(x)
        return z
   

    def decode(self, z):
        z = self.decoder_fc1(z)
        z = self.decoder_fc2(z)
        x = self.decoder_reshape(z)
        return x

autoencoder = AutoEncoder(32)
autoencoder.summary()

autoencoder.compile(
    optimizer='rmsprop',
    loss='binary_crossentropy'
)

Model: "auto_encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape (Reshape)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 32)                25120     
_________________________________________________________________
dense_1 (Dense)              (None, 32)                1056      
_________________________________________________________________
dense_2 (Dense)              (None, 32)                1056      
_________________________________________________________________
dense_3 (Dense)              (None, 784)               25872     
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
Total params: 53,104
Trainable params: 53,104
Non-trainable params: 0
__________________________________________________
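The parameter counts in this summary follow from one rule: a Dense layer has in×out weights plus out biases, and Reshape layers have no parameters. A quick plain-Python check (the helper name is ours, for illustration):

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


pixels = 28 * 28      # flattened 28x28x1 spectrogram
hidden, latent = 32, 32
layer_params = [
    dense_params(pixels, hidden),   # encoder_fc1
    dense_params(hidden, latent),   # encoder_fc2
    dense_params(latent, hidden),   # decoder_fc1
    dense_params(hidden, pixels),   # decoder_fc2
]
total = sum(layer_params)
print(layer_params, total)
```

Almost all parameters sit in the two layers touching the 784-pixel input and output, so the latent size barely changes the model size here.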

In [9]:
autoencoder.summary

<bound method Model.summary of <__main__.AutoEncoder object at 0x0000019609D97D68>>

## #Chunk 8
### Train the model

Here we try to feed the recordings into the model; so far only the first recording type (`x[0]`) is used, with each input also serving as its own reconstruction target.

In [48]:
#list(x[0].as_numpy_iterator())
print(x[0])
print(x[0].batch(256))
print(x[0].take(6))
#dataset

<MapDataset shapes: (1, 28, 28, 1), types: tf.float32>
<BatchDataset shapes: (None, 1, 28, 28, 1), types: tf.float32>
<TakeDataset shapes: (1, 28, 28, 1), types: tf.float32>


In [12]:
history_list = {}
#dataset = tf.data.Dataset.from_tensor_slices((x[0],x[0]))
dataset = tf.data.Dataset.zip((x[0],x[0]))

history = autoencoder.fit(
    dataset.batch(256),
    epochs = 20
)

history_list['base'] = history

Epoch 1/20
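The approach described at the top trains on healthy cases only, but this cell fits on `x[0]`, which contains every case. A minimal NumPy sketch of the intended split, assuming the spectrograms and labels have already been collected into aligned arrays; `ood_split`, the 20 % hold-out and the seed are illustrative choices.

```python
import numpy as np


def ood_split(X, y, test_fraction=0.2, seed=0):
    """Train on healthy cases only; evaluate on held-out healthy cases plus all positives.

    X: array of shape (n_cases, ...); y: 0 = healthy, 1 = positive.
    """
    X, y = np.asarray(X), np.asarray(y)
    rng = np.random.default_rng(seed)
    healthy_idx = rng.permutation(np.flatnonzero(y == 0))
    n_test = int(round(test_fraction * len(healthy_idx)))
    test_idx = np.concatenate([healthy_idx[:n_test], np.flatnonzero(y == 1)])
    train_idx = healthy_idx[n_test:]
    return X[train_idx], X[test_idx], y[test_idx]


# Toy example: 8 healthy and 2 positive "spectrograms"
X = np.zeros((10, 28, 28, 1))
y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
X_train, X_test, y_test = ood_split(X, y, test_fraction=0.25)
print(X_train.shape, X_test.shape, y_test)
```

`X_train` could then be wrapped with `tf.data.Dataset.from_tensor_slices(X_train)` and batched; for the plain auto-encoder, the same tensor serves as input and target.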


## #Chunk 9
### Variational Auto-Encoder Architecture

In [7]:
from tensorflow import keras
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the latent vector encoding a spectrogram."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
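The `Sampling` layer implements the reparameterization trick: instead of sampling z directly, it draws ε from a standard normal and computes z = z_mean + exp(0.5·z_log_var)·ε, so gradients can flow back to `z_mean` and `z_log_var`. A NumPy sketch of the same formula (function name and toy values are illustrative) shows that exp(0.5·log_var) acts as the standard deviation:

```python
import numpy as np


def sample_latent(z_mean, z_log_var, rng):
    """Reparameterization trick: z = mean + exp(0.5 * log_var) * epsilon, epsilon ~ N(0, 1)."""
    epsilon = rng.standard_normal(np.shape(z_mean))
    return z_mean + np.exp(0.5 * z_log_var) * epsilon


rng = np.random.default_rng(42)
z_mean = np.full((100_000, 2), 3.0)
z_log_var = np.full((100_000, 2), np.log(0.25))   # variance 0.25, i.e. standard deviation 0.5
z = sample_latent(z_mean, z_log_var, rng)
print(z.mean(), z.std())
```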


In [8]:
latent_dim = 2

encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x) #no activation: a ReLU here would force log-variance >= 0, i.e. variance >= 1
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()


Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 14, 14, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 7, 7, 64)     18496       conv2d[0][0]                     
__________________________________________________________________________________________________
flatten (Flatten)               (None, 3136)         0           conv2d_1[0][0]                   
____________________________________________________________________________________________
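The shapes and parameter counts in this summary can be checked by hand: with `padding="same"` and stride 2, each Conv2D halves the spatial size (rounding up), and a Conv2D layer has k×k×C_in×C_out weights plus C_out biases. A plain-Python check with illustrative helper names:

```python
import math


def conv_same_output(size, stride):
    """Spatial size after a 'same'-padded convolution: ceil(size / stride)."""
    return math.ceil(size / stride)


def conv_params(kernel, c_in, c_out):
    """kernel x kernel x c_in weights per filter, plus one bias per filter."""
    return kernel * kernel * c_in * c_out + c_out


h1 = conv_same_output(28, 2)   # after the first Conv2D
h2 = conv_same_output(h1, 2)   # after the second Conv2D
flat = h2 * h2 * 64            # size of the Flatten output
p1 = conv_params(3, 1, 32)
p2 = conv_params(3, 32, 64)
print(h1, h2, flat, p1, p2)
```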

In [9]:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()


Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense_1 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 1)         289       
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_______________________________________________________
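The decoder mirrors the encoder: with `padding="same"`, a Conv2DTranspose multiplies the spatial size by its stride, and its parameter count follows the same k×k×C_in×C_out + C_out rule as Conv2D. A plain-Python check of the shapes and of the 65,089 total (helper names are illustrative):

```python
def dense_params(n_in, n_out):
    return n_in * n_out + n_out


def conv_transpose_params(kernel, c_in, c_out):
    # Same count as Conv2D: kernel x kernel x c_in x c_out weights + c_out biases
    return kernel * kernel * c_in * c_out + c_out


def conv_transpose_same_output(size, stride):
    # With padding="same", Conv2DTranspose multiplies the spatial size by the stride
    return size * stride


latent_dim = 2
params = [
    dense_params(latent_dim, 7 * 7 * 64),
    conv_transpose_params(3, 64, 64),
    conv_transpose_params(3, 64, 32),
    conv_transpose_params(3, 32, 1),
]
sizes = [7]
for stride in (2, 2, 1):
    sizes.append(conv_transpose_same_output(sizes[-1], stride))
print(params, sum(params), sizes)
```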

In [10]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
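The `kl_loss` term in `train_step` is the closed-form KL divergence between the encoder's Gaussian N(z_mean, exp(z_log_var)) and the standard-normal prior, summed over the latent dimensions. A small NumPy check of that formula (the function name is illustrative): it is zero when the posterior equals the prior and grows as the mean moves away from 0 or the variance away from 1.

```python
import numpy as np


def kl_to_standard_normal(z_mean, z_log_var):
    """KL(N(mean, exp(log_var)) || N(0, 1)) summed over the last axis, as in train_step."""
    z_mean = np.asarray(z_mean, dtype=float)
    z_log_var = np.asarray(z_log_var, dtype=float)
    return -0.5 * np.sum(1 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=-1)


kl_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])    # posterior equals the prior
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])  # mean shifted by one unit
kl_batch = kl_to_standard_normal(np.zeros((3, 2)), np.zeros((3, 2)))
print(kl_prior, kl_shifted, kl_batch)
```

Because `train_step` then averages this per-sample value over the batch with `tf.reduce_mean`, the KL term and the summed reconstruction loss are both on a per-sample scale.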




In [11]:
vae_input = x_[0].batch(256)
vae_input
#vae_input.reshape(None, 28, 28, 1)

<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

In [12]:
vae_input = x_[0].batch(5500)

mymodel = VAE(encoder, decoder)
mymodel.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-6))
mymodel.fit(
    vae_input,
    epochs = 20
)
mymodel.summary()

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20


KeyboardInterrupt: 
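With the 1328 full-length clips from Chunk 6, `batch(5500)` yields a single batch, so each epoch performs only one gradient update; together with a learning rate of 1e-6, the weights can be expected to change very little per epoch. The batch arithmetic, as a plain-Python sketch (tf.data's `batch` keeps the final partial batch by default, hence the ceiling):

```python
import math


def steps_per_epoch(n_samples, batch_size):
    """Number of batches per pass over the data when the last partial batch is kept."""
    return math.ceil(n_samples / batch_size)


n_clips = 1328   # CLEAN count printed in Chunk 6
for bs in (32, 256, 5500):
    print(bs, steps_per_epoch(n_clips, bs))
```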

In [103]:
history_list = {}

history = mymodel.fit(
    x[0],
    epochs = 20,
    batch_size=32

)

history_list['base'] = history

Epoch 1/20


ValueError: in user code:

    C:\Users\paulg\.conda\envs\corona\lib\site-packages\keras\engine\training.py:853 train_function  *
        return step_function(self, iterator)
    C:\Users\paulg\.conda\envs\corona\lib\site-packages\keras\engine\training.py:842 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    C:\Users\paulg\.conda\envs\corona\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1286 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    C:\Users\paulg\.conda\envs\corona\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2849 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    C:\Users\paulg\.conda\envs\corona\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3632 _call_for_each_replica
        return fn(*args, **kwargs)
    C:\Users\paulg\.conda\envs\corona\lib\site-packages\keras\engine\training.py:835 run_step  **
        outputs = model.train_step(data)
    <ipython-input-37-3815e43f8a8f>:22 train_step
        z_mean, z_log_var, z = self.encoder(data)
    C:\Users\paulg\.conda\envs\corona\lib\site-packages\keras\engine\base_layer.py:1020 __call__
        input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
    C:\Users\paulg\.conda\envs\corona\lib\site-packages\keras\engine\input_spec.py:269 assert_input_compatibility
        ', found shape=' + display_shape(x.shape))

    ValueError: Input 0 is incompatible with layer encoder: expected shape=(None, 28, 28, 1), found shape=(None, 28, 32)
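The shape in this error is consistent with `x` no longer holding the spectrogram datasets: the encoder and decoder cells of Chunk 9 reassign `x` to intermediate layer outputs, the last one being the `Conv2DTranspose(32, ...)` output of shape `(None, 28, 28, 32)`. Indexing it with `[0]` gives a `(28, 28, 32)` tensor whose first axis Keras then treats as the batch axis, leaving a per-sample shape of `(28, 32)` instead of `(28, 28, 1)`. A NumPy stand-in (batch size 4 chosen arbitrarily) reproduces the shapes:

```python
import numpy as np

# Stand-in for the tensor `x` most likely holds after the decoder cell:
# the output of Conv2DTranspose(32, ...), shape (batch, 28, 28, 32)
x_shadowed = np.zeros((4, 28, 28, 32))

first = x_shadowed[0]           # what x[0] selects: one (28, 28, 32) tensor
sample_shape = first.shape[1:]  # fit() reads axis 0 as the sample axis
print(first.shape, sample_shape)
```

Passing the spectrogram dataset instead (for example `x_[0].batch(32)`, as in the earlier VAE cell) or giving the layer outputs a name other than `x` should avoid the clash. Since a batched `tf.data.Dataset` already yields batches, the Keras documentation advises against also passing `batch_size` to `fit`.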
