In [203]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Bidirectional, LSTM, Reshape, RepeatVector, TimeDistributed
from keras.layers import BatchNormalization, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
tf.compat.v1.enable_eager_execution() 

import matplotlib.pyplot as plt

import sys

import numpy as np
import pandas as pd

import os

from PIL import Image

# Load Data
Loading from preprocessed numpy array

In [204]:
def load_data():
    x_train = np.array(pd.read_csv("./data_stock/SP500_average.csv",).drop([0]).drop(columns=['Date']))
    return x_train
            

def split_time_series(t, arr) -> (np.array, np.array):
    a = []
    b = []
    for i in range(len(arr)-t):
        a.append(arr[i: i+t])
        b.append(arr[i+t])
    return (np.array(a), np.array(b))
    

def merge_time_series(arr1, arr2) -> tf.float64 :
    t1 = tf.cast(arr1, 'float64') if tf.is_tensor(arr1) else tf.convert_to_tensor(arr1, dtype='float64')
    t2 = tf.cast(arr2, 'float64') if tf.is_tensor(arr2) else tf.convert_to_tensor(arr2, dtype='float64')
    return tf.concat([t1, tf.expand_dims(t2, 1) ], axis=1)
                                
split = split_time_series(5, load_data())
merged = merge_time_series(*split)

print(split)
print(merged)

(array([[[4366.64, 4411.01, 4287.11, 4356.45],
        [4356.32, 4417.35, 4222.62, 4410.13],
        [4471.38, 4494.52, 4395.34, 4397.94],
        [4547.35, 4602.11, 4477.95, 4482.73],
        [4588.03, 4611.55, 4530.2 , 4532.76]],

       [[4356.32, 4417.35, 4222.62, 4410.13],
        [4471.38, 4494.52, 4395.34, 4397.94],
        [4547.35, 4602.11, 4477.95, 4482.73],
        [4588.03, 4611.55, 4530.2 , 4532.76],
        [4632.24, 4632.24, 4568.7 , 4577.11]],

       [[4471.38, 4494.52, 4395.34, 4397.94],
        [4547.35, 4602.11, 4477.95, 4482.73],
        [4588.03, 4611.55, 4530.2 , 4532.76],
        [4632.24, 4632.24, 4568.7 , 4577.11],
        [4637.99, 4665.13, 4614.75, 4662.85]],

       ...,

       [[2457.77, 2571.42, 2407.53, 2475.56],
        [2344.44, 2449.71, 2344.44, 2447.33],
        [2290.71, 2300.73, 2191.86, 2237.4 ],
        [2431.94, 2453.01, 2295.56, 2304.92],
        [2393.48, 2466.97, 2319.78, 2409.39]],

       [[2344.44, 2449.71, 2344.44, 2447.33],
        [229

# Creating GAN

In [216]:
class LSTMGAN():
    def __init__(self, t, f, data):
        # standardize data
        self.scaler = StandardScaler()
        self.scaler.fit(data)
        self.data = self.scaler.transform(data)
        
        self.time_series_len = t
        self.feature_len = f
        self.gen_shape = (self.time_series_len, self.feature_len)
        self.dis_shape = (self.time_series_len+1, self.feature_len)

        optimizer = Adam(0.0001, 0.4)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates song
        real_input = Input(shape=self.gen_shape)
        gen_output = self.generator(real_input)
        print(gen_output)
        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(merge_time_series(real_input, gen_output))

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(real_input, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        model = Sequential()
        model.add(Bidirectional(LSTM(128, return_sequences=True), input_shape=self.gen_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Bidirectional(LSTM(128)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(4))
        # model.summary()

        noise = Input(shape=self.gen_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()
        
        model.add(Bidirectional(LSTM(128, activation = 'relu', return_sequences=True), input_shape=self.dis_shape))\n",
        model.add(Dropout(0.4))
        model.add(TimeDistributed(Dense(128, activation = 'relu')))
        model.add(TimeDistributed(Dense(1, activation = 'linear')))
        #model.summary()

        img = Input(shape=self.dis_shape)
        validity = model(img)

        return Model(img, validity)
    

    def train(self, epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train_input, X_train_output) = split_time_series(self.time_series_len, self.data)

        # normalize

        # Adversarial ground truths
        valid = np.ones((batch_size,1,1))
        fake = np.zeros((batch_size,1,1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half of songs
            idx = np.random.randint(0, X_train_input.shape[0], batch_size)
            real_input= X_train_input[idx]
            real_output= X_train_output[idx]

            # Sample noise and generate a batch of new songs
            noise = np.random.normal(1000, 300, (batch_size,self.time_series_len,self.feature_len))
            gen_output = self.generator.predict(noise)
            real_series = merge_time_series(real_input,real_output)
            fake_series = merge_time_series(real_input,gen_output)
            print(self.scaler.inverse_transform(fake_series[0]))

            # Train the discriminator (real classified as ones and generated as zeros)
            d_loss_real = self.discriminator.train_on_batch(real_series, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_series, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Train the generator (wants discriminator to mistake songs as real)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save model
            if epoch % save_interval == 0:
                self.generator.save("LSTM_generator.h5")

# Model Summary
I couldn't train the model on this online notebook so I trained it locally for 1000 epochs and uploaded the h5 file.

In [None]:
lstmgan = LSTMGAN(5, 4, load_data())
lstmgan.train(epochs=1000, batch_size=100, save_interval=100)

KerasTensor(type_spec=TensorSpec(shape=(None, 4), dtype=tf.float32, name=None), name='model_121/sequential_96/dense_186/BiasAdd:0', description="created by layer 'model_121'")
[[3802.23       3820.96       3791.5        3809.84      ]
 [3801.62       3810.78       3776.51       3801.19      ]
 [3803.14       3817.86       3789.02       3799.61      ]
 [3815.05       3826.69       3783.6        3824.68      ]
 [3764.71       3811.55       3764.71       3803.79      ]
 [3997.23277184 3892.33843092 3899.78474329 3847.75398323]]
0 [D loss: 4.226187, acc.: 50.00%] [G loss: 0.730356]
[[3705.98       3712.39       3660.54       3672.82      ]
 [3683.05       3708.45       3678.83       3702.25      ]
 [3694.73       3697.41       3678.88       3691.96      ]
 [3670.94       3699.2        3670.94       3699.12      ]
 [3668.28       3682.73       3657.17       3666.72      ]
 [4013.35806915 3942.50089603 3925.05434944 3835.28793037]]
1 [D loss: 2.609390, acc.: 50.00%] [G loss: 0.666453]
[[3733

19 [D loss: 1.228730, acc.: 50.00%] [G loss: 0.338566]
[[4547.35       4602.11       4477.95       4482.73      ]
 [4588.03       4611.55       4530.2        4532.76      ]
 [4632.24       4632.24       4568.7        4577.11      ]
 [4637.99       4665.13       4614.75       4662.85      ]
 [4733.56       4744.13       4650.29       4659.03      ]
 [4478.42423585 4002.09274485 3790.81606029 3572.16514689]]
20 [D loss: 1.313081, acc.: 50.00%] [G loss: 0.331387]
[[4474.1        4513.33       4474.1        4509.37      ]
 [4493.75       4495.9        4468.99       4470.        ]
 [4490.45       4501.71       4485.66       4496.19      ]
 [4484.4        4492.81       4482.28       4486.23      ]
 [4450.29       4489.88       4450.29       4479.53      ]
 [4577.41870815 3977.4191002  3784.004524   3585.85842911]]
21 [D loss: 1.110367, acc.: 50.00%] [G loss: 0.322209]
[[4438.04       4463.12       4430.27       4455.48      ]
 [4406.75       4465.4        4406.75       4448.98      ]
 [4367.

40 [D loss: 1.006195, acc.: 51.00%] [G loss: 0.210835]
[[2787.89       2815.1        2775.95       2799.31      ]
 [2784.81       2785.54       2727.1        2736.56      ]
 [2845.62       2868.98       2820.43       2823.16      ]
 [2842.43       2879.22       2830.88       2874.56      ]
 [2799.34       2806.51       2764.32       2799.55      ]
 [5452.55440255 4470.89141806 4035.95525374 3568.22500235]]
41 [D loss: 0.981616, acc.: 51.25%] [G loss: 0.207375]
[[3764.71       3811.55       3764.71       3803.79      ]
 [3712.2        3783.04       3705.34       3748.14      ]
 [3698.02       3737.83       3695.07       3726.86      ]
 [3764.61       3769.99       3662.71       3700.65      ]
 [3733.27       3760.2        3726.88       3756.07      ]
 [5668.84339044 4374.84131416 3961.85258058 3677.55634842]]
42 [D loss: 0.964061, acc.: 51.33%] [G loss: 0.200292]
[[3566.82       3589.81       3552.77       3577.59      ]
 [3579.31       3581.23       3556.85       3557.54      ]
 [3559.

[[4096.11       4129.48       4095.51       4128.8       ]
 [4089.95       4098.19       4082.54       4097.17      ]
 [4074.29       4083.13       4068.31       4079.95      ]
 [4075.57       4086.23       4068.14       4073.94      ]
 [4034.44       4083.42       4034.44       4077.91      ]
 [6670.80333495 5063.72308503 4796.76837686 4303.60504306]]
61 [D loss: 0.931577, acc.: 45.83%] [G loss: 0.127887]
[[4707.25       4708.53       4670.87       4685.25      ]
 [4701.48       4714.92       4694.39       4701.7       ]
 [4699.26       4718.5        4681.32       4697.53      ]
 [4662.93       4683.         4662.59       4680.06      ]
 [4630.65       4663.46       4621.19       4660.57      ]
 [6681.27323388 5098.17053144 4983.406801   4294.26095209]]
62 [D loss: 0.933606, acc.: 42.42%] [G loss: 0.113561]
[[2782.46       2782.46       2721.17       2761.63      ]
 [2776.99       2818.57       2762.36       2789.82      ]
 [2685.         2760.75       2663.3        2749.98      ]
 [2

81 [D loss: 0.964747, acc.: 41.92%] [G loss: 0.069205]
[[2558.98       2631.8        2545.28       2626.65      ]
 [2555.87       2615.91       2520.02       2541.47      ]
 [2501.29       2637.01       2500.72       2630.07      ]
 [2457.77       2571.42       2407.53       2475.56      ]
 [2344.44       2449.71       2344.44       2447.33      ]
 [6184.39686669 5214.07254607 6413.53497373 5319.90989263]]
82 [D loss: 0.935656, acc.: 42.25%] [G loss: 0.069724]
[[4255.28       4257.16       4238.35       4246.59      ]
 [4248.31       4255.59       4234.07       4255.15      ]
 [4242.9        4248.38       4232.25       4247.44      ]
 [4228.56       4249.74       4220.34       4239.18      ]
 [4232.99       4237.09       4218.74       4219.55      ]
 [6137.80100757 5244.51217904 6406.34585418 5401.8976419 ]]
83 [D loss: 0.935893, acc.: 42.58%] [G loss: 0.071012]
[[3509.73       3514.77       3493.25       3500.31      ]
 [3494.69       3509.23       3484.32       3508.01      ]
 [3485.

101 [D loss: 1.003157, acc.: 42.42%] [G loss: 0.071425]
[[4173.4        4226.24       4173.4        4224.79      ]
 [4204.78       4204.78       4164.4        4166.45      ]
 [4220.37       4232.29       4196.05       4221.86      ]
 [4248.87       4251.89       4202.45       4223.7       ]
 [4255.28       4257.16       4238.35       4246.59      ]
 [5975.22903951 6107.20230996 6632.92503972 6339.16806411]]
102 [D loss: 1.008272, acc.: 42.33%] [G loss: 0.070322]
[[4204.78       4204.78       4164.4        4166.45      ]
 [4220.37       4232.29       4196.05       4221.86      ]
 [4248.87       4251.89       4202.45       4223.7       ]
 [4255.28       4257.16       4238.35       4246.59      ]
 [4248.31       4255.59       4234.07       4255.15      ]
 [5981.54047941 6188.08383259 6639.65607262 6329.68233978]]
103 [D loss: 0.974822, acc.: 42.83%] [G loss: 0.071463]
[[4206.05       4233.45       4206.05       4229.89      ]
 [4191.43       4204.39       4167.93       4192.85      ]
 [42

121 [D loss: 1.000439, acc.: 42.50%] [G loss: 0.072317]
[[4513.02       4529.9        4492.07       4493.28      ]
 [4518.09       4521.79       4493.95       4514.07      ]
 [4535.38       4535.38       4513.         4520.03      ]
 [4532.42       4541.45       4521.3        4535.43      ]
 [4534.48       4545.85       4524.66       4536.95      ]
 [5821.77282776 6637.60735641 6609.06488045 6961.68059077]]
122 [D loss: 1.005836, acc.: 42.75%] [G loss: 0.074198]
[[2501.29       2637.01       2500.72       2630.07      ]
 [2457.77       2571.42       2407.53       2475.56      ]
 [2344.44       2449.71       2344.44       2447.33      ]
 [2290.71       2300.73       2191.86       2237.4       ]
 [2431.94       2453.01       2295.56       2304.92      ]
 [5759.82683268 6633.76767335 6636.54445691 6957.00371395]]
123 [D loss: 0.995543, acc.: 42.50%] [G loss: 0.076266]
[[3438.5        3460.53       3415.34       3453.49      ]
 [3439.91       3464.86       3433.06       3435.56      ]
 [34

141 [D loss: 0.952602, acc.: 42.67%] [G loss: 0.093347]
[[3705.98       3712.39       3660.54       3672.82      ]
 [3683.05       3708.45       3678.83       3702.25      ]
 [3694.73       3697.41       3678.88       3691.96      ]
 [3670.94       3699.2        3670.94       3699.12      ]
 [3668.28       3682.73       3657.17       3666.72      ]
 [5755.14086392 7170.96308373 6896.82940955 7334.39698801]]
142 [D loss: 0.925138, acc.: 42.67%] [G loss: 0.095906]
[[3435.95       3444.21       3425.84       3443.62      ]
 [3418.09       3432.09       3413.13       3431.28      ]
 [3386.01       3399.96       3379.31       3397.16      ]
 [3360.48       3390.8        3354.69       3385.51      ]
 [3392.51       3399.54       3369.66       3374.85      ]
 [5769.54524716 7275.44085079 6944.64087195 7301.73468197]]
143 [D loss: 0.933901, acc.: 43.00%] [G loss: 0.097318]
[[2815.01       2844.24       2797.85       2842.74      ]
 [2869.09       2869.09       2821.61       2830.71      ]
 [29

161 [D loss: 0.921011, acc.: 42.58%] [G loss: 0.117937]
[[2784.81       2785.54       2727.1        2736.56      ]
 [2845.62       2868.98       2820.43       2823.16      ]
 [2842.43       2879.22       2830.88       2874.56      ]
 [2799.34       2806.51       2764.32       2799.55      ]
 [2795.64       2801.88       2761.54       2783.36      ]
 [5892.47824906 7890.0822856  7344.0095723  7424.26227084]]
162 [D loss: 0.914505, acc.: 42.67%] [G loss: 0.117415]
[[2457.77       2571.42       2407.53       2475.56      ]
 [2344.44       2449.71       2344.44       2447.33      ]
 [2290.71       2300.73       2191.86       2237.4       ]
 [2431.94       2453.01       2295.56       2304.92      ]
 [2393.48       2466.97       2319.78       2409.39      ]
 [5846.79196553 7845.12617914 7314.3082164  7337.59308897]]
163 [D loss: 0.912960, acc.: 42.83%] [G loss: 0.118062]
[[3723.03       3740.51       3723.03       3735.36      ]
 [3694.03       3703.82       3689.32       3703.06      ]
 [36

[[3764.61       3769.99       3662.71       3700.65      ]
 [3733.27       3760.2        3726.88       3756.07      ]
 [3736.19       3744.63       3730.21       3732.04      ]
 [3750.01       3756.12       3723.31       3727.04      ]
 [3723.03       3740.51       3723.03       3735.36      ]
 [6024.05126321 8462.03905393 7843.97534376 6861.87194802]]
182 [D loss: 0.861840, acc.: 50.33%] [G loss: 0.135963]
[[2878.26       2901.92       2876.48       2881.19      ]
 [2883.14       2891.11       2847.65       2848.42      ]
 [2868.88       2898.23       2863.55       2868.44      ]
 [2815.01       2844.24       2797.85       2842.74      ]
 [2869.09       2869.09       2821.61       2830.71      ]
 [6058.43336773 8466.42768049 7852.4173954  6825.96886498]]
183 [D loss: 0.878760, acc.: 50.75%] [G loss: 0.137218]
[[3015.65       3036.25       2969.75       3036.13      ]
 [3004.08       3021.72       2988.17       2991.77      ]
 [2948.05       2956.76       2933.59       2955.45      ]
 

201 [D loss: 0.833915, acc.: 50.75%] [G loss: 0.199638]
[[4139.76       4173.49       4139.76       4170.42      ]
 [4141.58       4151.69       4120.87       4124.66      ]
 [4130.1        4148.         4124.43       4141.59      ]
 [4124.71       4131.76       4114.82       4127.99      ]
 [4096.11       4129.48       4095.51       4128.8       ]
 [6207.26802368 8961.66294801 8338.61689365 6044.99304113]]
202 [D loss: 0.809970, acc.: 51.33%] [G loss: 0.177762]
[[4372.41       4386.68       4364.03       4384.63      ]
 [4329.38       4371.6        4329.38       4369.55      ]
 [4321.07       4330.88       4289.37       4320.82      ]
 [4351.01       4361.88       4329.79       4358.13      ]
 [4356.46       4356.46       4314.37       4343.54      ]
 [6173.65725691 8949.11349772 8331.27300709 6013.64382882]]
203 [D loss: 0.824857, acc.: 50.58%] [G loss: 0.206693]
[[3698.08       3698.26       3676.16       3687.26      ]
 [3684.28       3702.9        3636.48       3694.92      ]
 [37

[[4628.75       4672.95       4625.26       4655.27      ]
 [4664.63       4664.63       4585.43       4594.62      ]
 [4675.78       4702.87       4659.89       4701.46      ]
 [4678.48       4699.39       4652.66       4690.7       ]
 [4712.         4743.83       4682.17       4682.94      ]
 [6563.31841199 9134.88369163 8776.31602399 5286.31335753]]
222 [D loss: 0.776664, acc.: 51.00%] [G loss: 0.355694]
[[3694.73       3697.41       3678.88       3691.96      ]
 [3670.94       3699.2        3670.94       3699.12      ]
 [3668.28       3682.73       3657.17       3666.72      ]
 [3653.78       3670.96       3644.84       3669.01      ]
 [3645.87       3678.45       3645.87       3662.45      ]
 [6485.48197686 9049.89823933 8724.85052157 5263.56692975]]
223 [D loss: 0.763166, acc.: 51.33%] [G loss: 0.311768]
[[3224.29       3258.61       3215.16       3251.84      ]
 [3224.21       3233.52       3205.65       3224.73      ]
 [3208.36       3220.39       3198.59       3215.57      ]
 

[[4655.24       4688.47       4650.77       4682.85      ]
 [4659.39       4664.55       4648.31       4649.27      ]
 [4670.26       4684.85       4630.86       4646.71      ]
 [4707.25       4708.53       4670.87       4685.25      ]
 [4701.48       4714.92       4694.39       4701.7       ]
 [4939.55464985 7723.87715096 9800.67944673 6193.76442815]]
242 [D loss: 0.754555, acc.: 50.92%] [G loss: 0.263901]
[[4640.25       4646.02       4560.         4567.        ]
 [4628.75       4672.95       4625.26       4655.27      ]
 [4664.63       4664.63       4585.43       4594.62      ]
 [4675.78       4702.87       4659.89       4701.46      ]
 [4678.48       4699.39       4652.66       4690.7       ]
 [4857.14820398 7652.86244226 9875.59436923 6316.87765534]]
243 [D loss: 0.764427, acc.: 50.67%] [G loss: 0.267917]
[[4697.66       4707.95       4662.74       4677.03      ]
 [4693.39       4725.01       4671.26       4696.05      ]
 [4787.99       4797.7        4699.44       4700.58      ]
 

262 [D loss: 0.716761, acc.: 50.50%] [G loss: 0.307214]
[[ 4438.04        4463.12        4430.27        4455.48      ]
 [ 4406.75        4465.4         4406.75        4448.98      ]
 [ 4367.43        4416.75        4367.43        4395.64      ]
 [ 4374.45        4394.87        4347.96        4354.19      ]
 [ 4402.95        4402.95        4305.91        4357.73      ]
 [ 4103.93003321  6901.98398113 10304.4155554   6765.77494433]]
263 [D loss: 0.719907, acc.: 50.17%] [G loss: 0.310629]
[[ 4228.29        4236.39        4188.13        4188.43      ]
 [ 4210.34        4238.04        4201.64        4232.6       ]
 [ 4169.14        4202.7         4147.33        4201.62      ]
 [ 4177.06        4187.72        4160.94        4167.59      ]
 [ 4179.04        4179.04        4128.59        4164.66      ]
 [ 4070.33932196  6869.68687994 10260.77362846  6714.57090266]]
264 [D loss: 0.719904, acc.: 49.42%] [G loss: 0.315076]
[[ 4233.81        4236.74        4208.41        4227.26      ]
 [ 4229.34 

[[ 3293.59        3304.93        3233.94        3269.96      ]
 [ 3277.17        3341.05        3259.82        3310.11      ]
 [ 3342.48        3342.48        3268.89        3271.03      ]
 [ 3403.15        3409.51        3388.71        3390.68      ]
 [ 3441.42        3441.42        3364.86        3400.97      ]
 [ 3424.24538736  6412.73129213 10610.5216017   6716.06843204]]
282 [D loss: 0.697195, acc.: 51.25%] [G loss: 0.467735]
[[ 4159.18        4159.18        4118.38        4134.94      ]
 [ 4179.8         4180.81        4150.47        4163.26      ]
 [ 4174.14        4191.31        4170.75        4185.47      ]
 [ 4139.76        4173.49        4139.76        4170.42      ]
 [ 4141.58        4151.69        4120.87        4124.66      ]
 [ 3422.31244731  6457.96143213 10653.50776636  6701.69626494]]
283 [D loss: 0.706615, acc.: 51.83%] [G loss: 0.484673]
[[ 2915.46        2944.25        2903.44        2930.32      ]
 [ 2908.83        2932.16        2902.88        2929.8       ]
 [ 2

[[ 4553.69        4572.62        4537.36        4566.48      ]
 [ 4546.12        4559.67        4524.          4544.9       ]
 [ 4532.24        4551.44        4526.89        4549.78      ]
 [ 4524.42        4540.87        4524.4         4536.19      ]
 [ 4497.34        4520.4         4496.41        4519.63      ]
 [ 2466.78699865  6320.6960497  10199.4225688   6155.39108295]]
301 [D loss: 0.683479, acc.: 53.08%] [G loss: 2.570824]
[[ 2787.89        2815.1         2775.95        2799.31      ]
 [ 2784.81        2785.54        2727.1         2736.56      ]
 [ 2845.62        2868.98        2820.43        2823.16      ]
 [ 2842.43        2879.22        2830.88        2874.56      ]
 [ 2799.34        2806.51        2764.32        2799.55      ]
 [ 2447.80341334  6299.35225572 10173.00978974  6146.92497185]]
302 [D loss: 0.690821, acc.: 52.17%] [G loss: 2.570824]
[[ 3384.56        3426.26        3384.56        3419.45      ]
 [ 3408.74        3431.56        3354.54        3360.95      ]
 [ 3

[[3814.98       3823.6        3792.86       3795.54      ]
 [3802.23       3820.96       3791.5        3809.84      ]
 [3801.62       3810.78       3776.51       3801.19      ]
 [3803.14       3817.86       3789.02       3799.61      ]
 [3815.05       3826.69       3783.6        3824.68      ]
 [1687.41659937 3516.19460204 8987.09210788 5643.90852906]]
321 [D loss: 0.684192, acc.: 52.83%] [G loss: 2.596533]
[[2555.87       2615.91       2520.02       2541.47      ]
 [2501.29       2637.01       2500.72       2630.07      ]
 [2457.77       2571.42       2407.53       2475.56      ]
 [2344.44       2449.71       2344.44       2447.33      ]
 [2290.71       2300.73       2191.86       2237.4       ]
 [1690.31532929 3439.95101659 8975.99327205 5577.4202823 ]]
322 [D loss: 0.688906, acc.: 52.25%] [G loss: 2.570824]
[[4229.34       4232.34       4215.66       4226.52      ]
 [4206.05       4233.45       4206.05       4229.89      ]
 [4191.43       4204.39       4167.93       4192.85      ]
 

[[4435.79       4445.21       4430.03       4436.75      ]
 [4437.77       4439.39       4424.74       4432.35      ]
 [4429.07       4440.82       4429.07       4436.52      ]
 [4408.86       4429.76       4408.86       4429.1       ]
 [4415.95       4416.17       4400.23       4402.66      ]
 [2398.60574154 2177.97047624 5921.51588114 1482.09099116]]
341 [D loss: 0.693561, acc.: 52.58%] [G loss: 0.352394]
[[2344.44       2449.71       2344.44       2447.33      ]
 [2290.71       2300.73       2191.86       2237.4       ]
 [2431.94       2453.01       2295.56       2304.92      ]
 [2393.48       2466.97       2319.78       2409.39      ]
 [2436.5        2453.57       2280.52       2398.1       ]
 [2417.09520009 2305.01700806 5681.85237373 1511.1606963 ]]
342 [D loss: 0.694321, acc.: 53.00%] [G loss: 0.349159]
[[4675.78       4702.87       4659.89       4701.46      ]
 [4678.48       4699.39       4652.66       4690.7       ]
 [4712.         4743.83       4682.17       4682.94      ]
 

[[4075.57       4086.23       4068.14       4073.94      ]
 [4034.44       4083.42       4034.44       4077.91      ]
 [3992.78       4020.63       3992.78       4019.87      ]
 [3967.25       3994.41       3966.98       3972.89      ]
 [3963.34       3968.01       3944.35       3958.55      ]
 [3130.48232396 2927.16522335 2973.37192773 1531.46346181]]
361 [D loss: 0.750771, acc.: 45.00%] [G loss: 0.107636]
[[3412.56       3425.55       3329.25       3339.19      ]
 [3369.82       3424.77       3366.84       3398.96      ]
 [3371.88       3379.97       3329.27       3331.84      ]
 [3453.6        3479.15       3349.63       3426.96      ]
 [3564.74       3564.85       3427.41       3455.06      ]
 [3124.95685444 3066.16383865 3055.17671905 1692.81835743]]
362 [D loss: 0.769951, acc.: 43.92%] [G loss: 0.104737]
[[3844.39       3881.06       3819.25       3821.35      ]
 [3793.58       3851.69       3730.19       3841.94      ]
 [3818.53       3843.67       3723.34       3768.47      ]
 

381 [D loss: 0.707853, acc.: 51.33%] [G loss: 0.250856]
[[2854.65       2887.72       2852.89       2878.48      ]
 [2812.64       2842.71       2791.76       2836.74      ]
 [2810.42       2844.9        2794.26       2797.8       ]
 [2787.89       2815.1        2775.95       2799.31      ]
 [2784.81       2785.54       2727.1        2736.56      ]
 [4086.5691065  2969.46790813 2050.77764761 1705.61878   ]]
382 [D loss: 0.720619, acc.: 51.42%] [G loss: 0.234740]
[[4419.54       4419.54       4346.33       4352.63      ]
 [4442.12       4457.3        4436.19       4443.11      ]
 [4438.04       4463.12       4430.27       4455.48      ]
 [4406.75       4465.4        4406.75       4448.98      ]
 [4367.43       4416.75       4367.43       4395.64      ]
 [4119.89446359 2956.12754389 2015.21216116 1835.42533157]]
383 [D loss: 0.710943, acc.: 51.58%] [G loss: 0.272825]
[[3357.38       3362.27       3292.4        3319.47      ]
 [3346.86       3375.17       3328.82       3357.01      ]
 [34

[[3842.51       3914.5        3842.51       3901.82      ]
 [3839.66       3861.08       3789.54       3811.15      ]
 [3915.8        3925.02       3814.04       3829.34      ]
 [3873.71       3928.65       3859.6        3925.43      ]
 [3857.07       3895.98       3805.59       3881.37      ]
 [3199.28278366 2259.46725965 2717.33278434 2247.47164109]]
402 [D loss: 0.702845, acc.: 51.00%] [G loss: 0.288111]
[[3764.61       3769.99       3662.71       3700.65      ]
 [3733.27       3760.2        3726.88       3756.07      ]
 [3736.19       3744.63       3730.21       3732.04      ]
 [3750.01       3756.12       3723.31       3727.04      ]
 [3723.03       3740.51       3723.03       3735.36      ]
 [3169.94415449 2258.97616118 2683.88365526 2347.77995016]]
403 [D loss: 0.697517, acc.: 51.50%] [G loss: 0.299693]
[[3270.45       3272.17       3220.26       3271.12      ]
 [3231.76       3250.92       3204.13       3246.22      ]
 [3227.22       3264.74       3227.22       3258.44      ]
 

[[4632.24       4632.24       4568.7        4577.11      ]
 [4637.99       4665.13       4614.75       4662.85      ]
 [4733.56       4744.13       4650.29       4659.03      ]
 [4728.59       4748.83       4706.71       4726.35      ]
 [4669.14       4714.13       4638.27       4713.07      ]
 [2500.35011748 2324.94447521 2292.8492171  3004.068157  ]]
423 [D loss: 0.778246, acc.: 42.58%] [G loss: 0.208530]
[[4733.99       4791.49       4733.99       4791.19      ]
 [4703.96       4740.74       4703.96       4725.79      ]
 [4650.36       4697.67       4645.53       4696.56      ]
 [4594.96       4651.14       4583.16       4649.23      ]
 [4587.9        4587.9        4531.1        4568.02      ]
 [2557.07233978 2439.17833263 2432.60529871 3030.92456075]]
424 [D loss: 0.776032, acc.: 43.33%] [G loss: 0.153335]
[[3892.59       3915.77       3892.59       3915.59      ]
 [3878.3        3894.56       3874.93       3886.83      ]
 [3836.66       3872.42       3836.66       3871.74      ]
 

443 [D loss: 0.744475, acc.: 43.92%] [G loss: 0.133983]
[[3515.47       3527.94       3480.55       3488.67      ]
 [3534.01       3534.01       3500.86       3511.93      ]
 [3500.02       3549.85       3499.61       3534.22      ]
 [3459.67       3482.34       3458.07       3477.13      ]
 [3434.28       3447.28       3428.15       3446.83      ]
 [2130.81198048 2440.2205657  2409.67551697 3114.83161656]]
444 [D loss: 0.751505, acc.: 43.58%] [G loss: 0.161436]
[[2993.76       3079.76       2965.66       3066.59      ]
 [3071.04       3088.42       2984.47       3041.31      ]
 [3123.53       3123.53       2999.49       3002.1       ]
 [3213.42       3223.27       3181.49       3190.14      ]
 [3213.32       3222.71       3193.11       3207.18      ]
 [2126.14292381 2499.26106081 2327.23101712 3178.46309458]]
445 [D loss: 0.750499, acc.: 43.33%] [G loss: 0.136197]
[[4640.25       4646.02       4560.         4567.        ]
 [4628.75       4672.95       4625.26       4655.27      ]
 [46

Loading pretrained model

Installinging Mido Library

# Generating Melody
Generating random input and letting model predict output

In [143]:

random = np.random.normal(0,1,(1,5,4))

predict = lstmgan.generator(random)

print(predict)

tf.Tensor([[ 0.08106427 -6.513324   -8.422126   -9.677593  ]], shape=(1, 4), dtype=float32)


# Back to MIDI
Save generated melody back to a .mid file

In [None]:
midler = MidiFile()
track = MidiTrack()
midler.tracks.append(track)
track.append(Message('program_change', program=2, time=0))
for x in range(16):
    track.append(Message('note_on', note=int(predict[0][x][0]), velocity=64, time=20))
    track.append(Message('note_off', note=int(predict[0][x][0]), velocity=64, time=20))
    midler.save('new_song.mid')