In [2]:
# Imports and shared TensorFlow session configuration.
# TensorFlow's native logging reads TF_CPP_MIN_LOG_LEVEL from the environment
# when the library is loaded, so it has to be set *before* `import tensorflow`
# to silence the import-time messages.
import copy
import os
import time
import warnings
from queue import Queue

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings("ignore")

import cv2
import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm import tqdm

# TF1 session config shared by every DQN: grow GPU memory on demand instead of
# reserving the whole device up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

In [3]:
class ExperienceReplayBatch():
    """Fixed-capacity ring buffer of transitions used for experience replay.

    Transitions are deep-copied on insertion, so later mutation of the caller's
    objects does not affect stored entries. Once ``buffer_size`` entries have
    been written, new entries overwrite the oldest slots in order.

    Attributes:
        buffer_size: maximum number of stored transitions.
        replay_buf: dict mapping slot index -> stored transition.
        replay_index: slot index one past the most recently written slot.
        current_buf_size: number of filled slots (never exceeds buffer_size).
    """

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.replay_buf = {}
        self.replay_index = 0
        self.current_buf_size = 0

    def store_replay(self, replay):
        """Store a deep copy of ``replay``, wrapping to slot 0 when full."""
        slot = 0 if self.replay_index == self.buffer_size else self.replay_index
        self.replay_buf[slot] = copy.deepcopy(replay)
        self.replay_index = slot + 1
        if self.replay_index > self.current_buf_size:
            self.current_buf_size = self.replay_index

    def getRandomBatch(self, batch_size):
        """Return up to ``batch_size`` distinct stored transitions, uniformly at random."""
        chosen_slots = np.random.permutation(self.current_buf_size)[:batch_size]
        return [self.replay_buf[slot] for slot in chosen_slots]

    def reset(self):
        """Drop every stored transition and start writing from slot 0 again."""
        self.replay_buf = {}
        self.replay_index = 0
        self.current_buf_size = 0

In [4]:
def processImg(state_images):
    """Turn a sequence of raw frames into one stacked observation.

    Every frame is converted to grayscale, resized to 84 wide x 110 tall, and
    cropped to rows ``101 - 84`` (17) up to ``-9`` (exclusive), which leaves an
    84x84 patch. The patches are stacked along a new last axis, one channel
    per input frame, giving an array of shape (84, 84, len(state_images)).

    NOTE(review): the conversion uses COLOR_BGR2GRAY. gym Atari frames are
    usually RGB, which would swap the red/blue weights — confirm before
    changing, since existing checkpoints were trained on this preprocessing.
    """
    crop_top, crop_bottom = 101 - 84, -9
    return np.dstack([
        cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY), dsize = (84, 110))[crop_top:crop_bottom]
        for frame in state_images
    ])

In [5]:
class DQN():
    """Convolutional Q-network built from Keras layers on a TF1 graph/session.

    Maps a batch of observations of shape ``state_shape`` to one Q-value per
    action. The same network both chooses actions and produces the bootstrap
    targets in ``train`` (there is no separate target network).

    NOTE: ``__init__`` calls ``tf.reset_default_graph()``, so constructing a
    second DQN starts a fresh default graph and session.
    """

    def __init__(self, state_shape, actions, lrate = 0.00025, 
                 momentum = 0.95, discount = 0.95, save_file = None, load_file = None):
        """Build the graph, then either initialise or restore the variables.

        Args:
            state_shape: (height, width, channels) of a single observation.
            actions: number of discrete actions (width of the output layer).
            lrate: Adam learning rate.
            momentum: unused; kept so existing call sites keep working.
            discount: factor applied to the bootstrapped next-state value.
            save_file: default checkpoint prefix used by ``save()``.
            load_file: checkpoint prefix restored here; ``None`` means the
                variables are freshly initialised.
        """
        tf.reset_default_graph()
        self._sess = tf.Session(config = config)
        self.state_shape = state_shape
        self.actions = actions
        self.save_file = save_file
        self.load_file = load_file
        self.discount = discount
        self._input = tf.placeholder(tf.float32, [None, state_shape[0], state_shape[1], state_shape[2]])
        self._action_value = tf.placeholder(tf.float32, [None, actions])
        self._initLayers()
        self.pred = self._predict()
        # Per-sample MSE over the action axis (shape [batch]); minimize() on a
        # non-scalar tensor optimises its sum.
        self.loss = keras.losses.MSE(self._action_value, self.pred)
        self.train_step = tf.train.AdamOptimizer(lrate).minimize(self.loss)
        # Build the Saver exactly once: every tf.train.Saver() adds new
        # save/restore ops to the graph, so creating one per save() call (which
        # happens every training episode) would grow the graph without bound.
        self._saver = tf.train.Saver(var_list = tf.global_variables())
        if load_file is None:
            self._sess.run(tf.global_variables_initializer())
        else:
            self.load()

    def _initLayers(self):
        """Create the layers: two valid-padded ReLU convolutions (16 filters 8x8
        stride 4, then 32 filters 4x4 stride 2), a flatten, a 256-unit ReLU dense
        layer and a linear output layer with one unit per action."""
        self.conv1 = keras.layers.Conv2D(16, 8, strides = (4, 4), padding = 'valid', 
                                         activation = 'relu', 
                                         input_shape = self.state_shape,
                                         kernel_initializer = 'VarianceScaling' 
                                        )
        self.conv2 = keras.layers.Conv2D(32, 4, strides = (2, 2), 
                                         padding = 'valid', 
                                         activation = 'relu',
                                         kernel_initializer = 'VarianceScaling'
                                        )
        self.flatten = keras.layers.Flatten()
        self.dense1 = keras.layers.Dense(256, activation = 'relu',
                                         kernel_initializer = 'VarianceScaling',
                                         bias_initializer = 'VarianceScaling'
                                        )
        self.dense2 = keras.layers.Dense(self.actions,
                                         kernel_initializer = 'VarianceScaling',
                                         bias_initializer = 'VarianceScaling'
                                        )

    def _predict(self):
        """Wire the layers onto ``self._input`` and return the Q-value tensor."""
        x = self.flatten(self.conv2(self.conv1(self._input)))
        return self.dense2(self.dense1(x))

    def predict(self, state_data):
        """Return Q-values for a batch of observations.

        Args:
            state_data: array of shape (batch,) + state_shape, expected to hold
                raw pixel values in [0, 255]; it is divided by 255 here.

        Returns:
            A one-element list holding the (batch, actions) Q-value array
            (callers index ``[0]``).
        """
        assert state_data[0].shape == self.state_shape
        return self._sess.run([self.pred], feed_dict = {
            self._input: state_data / 255
        })

    def train(self, replay_batch):
        """Run one Q-learning gradient step on a batch of transitions.

        Each transition is ``[state, action, reward, next_state, done]``. The
        regression target for the taken action is ``reward`` when ``done`` and
        ``reward + discount * max_a Q(next_state, a)`` otherwise. The other
        actions keep the network's current predictions as targets, so they
        contribute zero error.

        Returns:
            The mean per-sample loss of the batch (float).
        """
        sess = self._sess
        current_states = np.array([t[0] for t in replay_batch], dtype = np.float32) / 255
        next_states = np.array([t[-2] for t in replay_batch], dtype = np.float32) / 255
        taken_actions = np.array([t[1] for t in replay_batch], dtype = np.int64)
        rewards = np.array([t[2] for t in replay_batch], dtype = np.float32)
        not_done = np.array([not t[-1] for t in replay_batch], dtype = np.float32)

        next_q = sess.run(self.pred, feed_dict = {self._input: next_states})
        targets = sess.run(self.pred, feed_dict = {self._input: current_states})

        # Terminal transitions get no bootstrapped value (not_done == 0).
        targets[np.arange(len(replay_batch)), taken_actions] = (
            rewards + not_done * self.discount * next_q.max(axis = 1))

        _, loss = sess.run([self.train_step, self.loss], feed_dict = {
            self._action_value: targets,
            self._input: current_states
        })
        return float(np.mean(loss))

    def save(self, save_file = None):
        """Write all global variables to a checkpoint.

        Args:
            save_file: checkpoint prefix; defaults to ``self.save_file``.

        Returns:
            The checkpoint path returned by ``tf.train.Saver.save``.
        """
        path = self.save_file if save_file is None else save_file
        assert path is not None, "no checkpoint path given and save_file is None"
        return self._saver.save(self._sess, path)

    def load(self, load_file = None):
        """Restore all global variables from a checkpoint.

        Args:
            load_file: checkpoint prefix; defaults to ``self.load_file``.
        """
        path = self.load_file if load_file is None else load_file
        assert path is not None, "no checkpoint path given and load_file is None"
        self._saver.restore(self._sess, path)

    def __del__(self):
        # __init__ may have failed before the session was created.
        sess = getattr(self, '_sess', None)
        if sess is not None:
            sess.close()

In [6]:
def step(env, action, framecount = 4):
    """Repeat ``action`` for up to ``framecount`` env steps and stack the frames.

    Args:
        env: gym environment with the 4-tuple ``step`` API.
        action: discrete action to repeat.
        framecount: number of env steps (and stacked frames) per call.

    Returns:
        ``[action, total_reward, stacked_frames, done]`` where
        ``stacked_frames`` is ``processImg`` of the collected frames as uint8.
        If the episode ends early, the last (terminal) frame is repeated so the
        stack always has ``framecount`` frames.
    """
    reward = 0
    states = []
    done = False
    for _ in range(framecount):
        state, r, done, _info = env.step(action)
        states.append(state)
        reward += r
        if done:
            break
    # Pad with the terminal frame so every observation has the same depth
    # (must match framecount, not a fixed 4).
    if len(states) < framecount:
        states += [states[-1]] * (framecount - len(states))
    # Pixels are 0..255, so uint8 is enough; np.uint (uint64) made every
    # stored replay entry 8x larger than necessary.
    processed_images = processImg(states).astype(np.uint8)
    return [action, reward, processed_images, done]

In [7]:
def play(env, dqn, fire_interval = 10, render = False):
    """Play one greedy (argmax-Q) episode and return its total reward.

    After ``env.reset()`` the agent presses FIRE (action 1) once, then again
    every ``fire_interval`` agent steps (presumably to relaunch the ball after
    a lost life). Rewards and ``done`` from those FIRE steps are counted too,
    so the loop stops as soon as any step ends the episode.

    Args:
        env: gym environment with the 4-tuple ``step`` API.
        dqn: object whose ``predict`` returns ``[q_values]`` for a batch.
        fire_interval: agent steps between forced FIRE presses.
        render: if True, render after every agent step with a short delay.

    Returns:
        Sum of all rewards collected during the episode.
    """
    env.reset()
    init_state, total_reward, done, _ = env.step(1)
    state = processImg([init_state] * 4)
    fi = fire_interval
    while not done:
        q_values = dqn.predict(np.array([state]))[0][0]
        action = np.argmax(q_values)
        action, reward, state, done = step(env, action)
        total_reward += reward
        if render:
            time.sleep(0.05)
            env.render()
        fi -= 1
        if fi == 0 and not done:
            # Keep the FIRE step's reward and termination instead of dropping them.
            _, fire_reward, done, _ = env.step(1)
            total_reward += fire_reward
            fi = fire_interval
    return total_reward

In [8]:
def dqLearn(env, exp_buf, dqn, eps, epsdecay = 0.9994, episodes = 10000, 
            train_len = 200, min_buf_size = 2500, buf_clear_thres = 4000,
            render_interval = 10, batch_size = 48):
    """Train ``dqn`` with epsilon-greedy Q-learning from experience replay.

    Each episode: reset, press FIRE, then act epsilon-greedily until ``done``,
    storing ``[state, action, reward, next_state, done]`` in ``exp_buf``. After
    the episode, if the buffer holds at least ``min_buf_size`` transitions, run
    ``train_len`` minibatch updates, save a checkpoint, and every
    ``render_interval`` episodes print the reward of one greedy ``play``
    episode (played without rendering).

    Args:
        env: gym environment with the 4-tuple ``step`` API.
        exp_buf: replay buffer with ``store_replay``/``getRandomBatch``/``reset``.
        dqn: Q-network with ``predict``, ``train``, ``save`` and ``actions``.
        eps: exploration probability; kept constant (decay is disabled).
        epsdecay: multiplicative decay for ``eps``; currently unused.
        episodes: number of episodes to run.
        train_len: gradient steps per training phase.
        min_buf_size: buffer size required before training starts.
        buf_clear_thres: the whole buffer is emptied once it reaches this size.
        render_interval: evaluate every N episodes; 0 disables evaluation.
        batch_size: transitions per gradient step.
    """
    n_actions = dqn.actions
    for ep in range(episodes):

        print('Episode ', ep, "Current Replay Size", exp_buf.current_buf_size)
        env.reset()
        init_state, _, _, _ = env.step(1)
        prev_state = processImg([init_state] * 4)
        done = False

        while not done:
            # Explore with probability eps, otherwise act greedily.
            if np.random.rand() < eps:
                next_action = np.random.randint(n_actions)
            else:
                next_action = np.argmax(dqn.predict(np.array([prev_state]))[0][0])
            transition = [prev_state] + step(env, next_action)
            done = transition[-1]
            exp_buf.store_replay(transition)
            prev_state = transition[-2]

        if exp_buf.current_buf_size >= min_buf_size:
            avg_loss = 0
            for _ in tqdm(range(train_len)):
                avg_loss += dqn.train(exp_buf.getRandomBatch(batch_size))
            # Epsilon decay is intentionally disabled:
            # eps = max(0.1, epsdecay * eps)
            print('Average Loss', avg_loss/train_len)
            dqn.save()
            if render_interval != 0 and ep % render_interval == 0:
                print("Reward ", play(env, dqn, render = False))

        # NOTE(review): this empties the whole buffer well before its capacity is
        # reached, discarding all experience every few dozen episodes — confirm
        # this is intended (the ring buffer already evicts old entries).
        if exp_buf.current_buf_size >= buf_clear_thres:
            exp_buf.reset()
    

In [9]:
# Training setup: Breakout env, a 10000-transition replay buffer, and a
# Q-network over 4 stacked 84x84 frames with 4 actions. Weights are restored
# from 'my_model4'; new checkpoints are written to 'my_model5'.
benv = gym.make('Breakout-v4')
exp_buf = ExperienceReplayBatch(10000)
dqn = DQN(
    state_shape = (84, 84, 4),
    actions = 4,
    save_file = 'my_model5',
    load_file = 'my_model4',
)

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from my_model4


In [None]:
# Train with a constant exploration rate of 0.3. The env is closed even if
# training is interrupted (e.g. KeyboardInterrupt) or raises.
# To watch the agent instead: play(benv, dqn, render=True)
try:
    dqLearn(benv, exp_buf, dqn, 0.3, render_interval=10)
finally:
    benv.close()

Episode  0 Current Replay Size 0
Episode  1 Current Replay Size 68
Episode  2 Current Replay Size 130
Episode  3 Current Replay Size 220
Episode  4 Current Replay Size 362
Episode  5 Current Replay Size 478
Episode  6 Current Replay Size 630
Episode  7 Current Replay Size 729
Episode  8 Current Replay Size 829
Episode  9 Current Replay Size 978
Episode  10 Current Replay Size 1139
Episode  11 Current Replay Size 1236
Episode  12 Current Replay Size 1326
Episode  13 Current Replay Size 1422
Episode  14 Current Replay Size 1567
Episode  15 Current Replay Size 1700
Episode  16 Current Replay Size 1863
Episode  17 Current Replay Size 1970
Episode  18 Current Replay Size 2065
Episode  19 Current Replay Size 2202
Episode  20 Current Replay Size 2299
Episode  21 Current Replay Size 2433


100%|██████████| 200/200 [00:07<00:00, 27.24it/s]


Average Loss 0.007644126941449943
Episode  22 Current Replay Size 2500


100%|██████████| 200/200 [00:07<00:00, 27.84it/s]


Average Loss 0.004013495838347202
Episode  23 Current Replay Size 2577


100%|██████████| 200/200 [00:07<00:00, 28.43it/s]


Average Loss 0.0036057578178588286
Episode  24 Current Replay Size 2704


100%|██████████| 200/200 [00:07<00:00, 26.24it/s]


Average Loss 0.002887439000575492
Episode  25 Current Replay Size 2823


100%|██████████| 200/200 [00:07<00:00, 27.41it/s]


Average Loss 0.0028237630314348894
Episode  26 Current Replay Size 2997


100%|██████████| 200/200 [00:07<00:00, 26.89it/s]


Average Loss 0.0029915848905996746
Episode  27 Current Replay Size 3083


100%|██████████| 200/200 [00:07<00:00, 26.06it/s]


Average Loss 0.0029230480811869098
Episode  28 Current Replay Size 3226


100%|██████████| 200/200 [00:07<00:00, 27.28it/s]


Average Loss 0.0027226501610130073
Episode  29 Current Replay Size 3344


100%|██████████| 200/200 [00:07<00:00, 27.19it/s]


Average Loss 0.002720316539440923
Episode  30 Current Replay Size 3453


100%|██████████| 200/200 [00:07<00:00, 26.70it/s]


Average Loss 0.0026658537506591535
Reward  4.0
Episode  31 Current Replay Size 3556


100%|██████████| 200/200 [00:07<00:00, 26.37it/s]


Average Loss 0.0026766888929220536
Episode  32 Current Replay Size 3654


100%|██████████| 200/200 [00:07<00:00, 26.41it/s]


Average Loss 0.0023225492603766414
Episode  33 Current Replay Size 3762


100%|██████████| 200/200 [00:07<00:00, 28.16it/s]


Average Loss 0.0023933885421138267
Episode  34 Current Replay Size 3844


100%|██████████| 200/200 [00:08<00:00, 24.34it/s]


Average Loss 0.0024818979909953967
Episode  35 Current Replay Size 3969


100%|██████████| 200/200 [00:08<00:00, 22.50it/s]


Average Loss 0.002384245657982925
Episode  36 Current Replay Size 0
Episode  37 Current Replay Size 111
Episode  38 Current Replay Size 242
Episode  39 Current Replay Size 321
Episode  40 Current Replay Size 436
Episode  41 Current Replay Size 574
Episode  42 Current Replay Size 652
Episode  43 Current Replay Size 786
Episode  44 Current Replay Size 940
Episode  45 Current Replay Size 1060
Episode  46 Current Replay Size 1186
Episode  47 Current Replay Size 1318
Episode  48 Current Replay Size 1402
Episode  49 Current Replay Size 1476
Episode  50 Current Replay Size 1561
Episode  51 Current Replay Size 1638
Episode  52 Current Replay Size 1698
Episode  53 Current Replay Size 1817
Episode  54 Current Replay Size 1909
Episode  55 Current Replay Size 2008
Episode  56 Current Replay Size 2181
Episode  57 Current Replay Size 2272
Episode  58 Current Replay Size 2424


100%|██████████| 200/200 [00:07<00:00, 26.56it/s]


Average Loss 0.005897907728794964
Episode  59 Current Replay Size 2500


100%|██████████| 200/200 [00:07<00:00, 25.27it/s]


Average Loss 0.003591580931097273
Episode  60 Current Replay Size 2664


100%|██████████| 200/200 [00:07<00:00, 28.06it/s]


Average Loss 0.003218856977764516
Reward  9.0
Episode  61 Current Replay Size 2850


100%|██████████| 200/200 [00:07<00:00, 27.66it/s]


Average Loss 0.0030976863473188127
Episode  62 Current Replay Size 3015


100%|██████████| 200/200 [00:07<00:00, 27.21it/s]


Average Loss 0.002695317474038651
Episode  63 Current Replay Size 3143


100%|██████████| 200/200 [00:07<00:00, 26.62it/s]


Average Loss 0.00264393959970524
Episode  64 Current Replay Size 3300


100%|██████████| 200/200 [00:07<00:00, 26.38it/s]


Average Loss 0.0023633181264934443
Episode  65 Current Replay Size 3404


100%|██████████| 200/200 [00:07<00:00, 27.33it/s]


Average Loss 0.0025546838855370884
Episode  66 Current Replay Size 3496


100%|██████████| 200/200 [00:07<00:00, 26.58it/s]


Average Loss 0.0021928230661433195
Episode  67 Current Replay Size 3597


100%|██████████| 200/200 [00:07<00:00, 26.72it/s]


Average Loss 0.002376447794182848
Episode  68 Current Replay Size 3660


100%|██████████| 200/200 [00:07<00:00, 26.21it/s]


Average Loss 0.002225697837226713
Episode  69 Current Replay Size 3786


100%|██████████| 200/200 [00:07<00:00, 27.20it/s]


Average Loss 0.0022069939873957394
Episode  70 Current Replay Size 3884


100%|██████████| 200/200 [00:07<00:00, 27.59it/s]


Average Loss 0.0019707330634507027
Reward  13.0
Episode  71 Current Replay Size 3996


100%|██████████| 200/200 [00:07<00:00, 28.10it/s]


Average Loss 0.002638959671991568
Episode  72 Current Replay Size 0
Episode  73 Current Replay Size 69
Episode  74 Current Replay Size 199
Episode  75 Current Replay Size 350
Episode  76 Current Replay Size 421
Episode  77 Current Replay Size 489
Episode  78 Current Replay Size 581
Episode  79 Current Replay Size 689
Episode  80 Current Replay Size 809
Episode  81 Current Replay Size 958
Episode  82 Current Replay Size 1047
Episode  83 Current Replay Size 1155
Episode  84 Current Replay Size 1285
Episode  85 Current Replay Size 1493
Episode  86 Current Replay Size 1561
Episode  87 Current Replay Size 1722
Episode  88 Current Replay Size 1825
Episode  89 Current Replay Size 1886
Episode  90 Current Replay Size 1980
Episode  91 Current Replay Size 2066
Episode  92 Current Replay Size 2152
Episode  93 Current Replay Size 2288
Episode  94 Current Replay Size 2413
Episode  95 Current Replay Size 2481


100%|██████████| 200/200 [00:06<00:00, 28.74it/s]


Average Loss 0.006751320095111926
Episode  96 Current Replay Size 2608


100%|██████████| 200/200 [00:06<00:00, 28.61it/s]


Average Loss 0.004095372005055348
Episode  97 Current Replay Size 2712


100%|██████████| 200/200 [00:06<00:00, 28.68it/s]


Average Loss 0.0038453453592956054
Episode  98 Current Replay Size 2830


100%|██████████| 200/200 [00:07<00:00, 28.11it/s]


Average Loss 0.0034623010247014445
Episode  99 Current Replay Size 2931


100%|██████████| 200/200 [00:07<00:00, 27.18it/s]


Average Loss 0.0031208518349255155
Episode  100 Current Replay Size 3041


100%|██████████| 200/200 [00:07<00:00, 26.25it/s]


Average Loss 0.003584171658148987
Reward  4.0
Episode  101 Current Replay Size 3180


100%|██████████| 200/200 [00:07<00:00, 27.62it/s]


Average Loss 0.0031397920843058568
Episode  102 Current Replay Size 3274


100%|██████████| 200/200 [00:07<00:00, 28.01it/s]


Average Loss 0.0034436863362013025
Episode  103 Current Replay Size 3394


100%|██████████| 200/200 [00:07<00:00, 27.07it/s]


Average Loss 0.003143645053884633
Episode  104 Current Replay Size 3534


100%|██████████| 200/200 [00:07<00:00, 27.10it/s]


Average Loss 0.0031910801209354155
Episode  105 Current Replay Size 3656


100%|██████████| 200/200 [00:07<00:00, 27.69it/s]


Average Loss 0.003964872657088564
Episode  106 Current Replay Size 3758


100%|██████████| 200/200 [00:07<00:00, 27.34it/s]


Average Loss 0.003737760594813156
Episode  107 Current Replay Size 3891


100%|██████████| 200/200 [00:07<00:00, 27.45it/s]


Average Loss 0.003236620596920449
Episode  108 Current Replay Size 3993


100%|██████████| 200/200 [00:07<00:00, 27.12it/s]


Average Loss 0.0031662595784291638
Episode  109 Current Replay Size 0
Episode  110 Current Replay Size 100
Episode  111 Current Replay Size 259
Episode  112 Current Replay Size 329
Episode  113 Current Replay Size 400
Episode  114 Current Replay Size 527
Episode  115 Current Replay Size 603
Episode  116 Current Replay Size 711
Episode  117 Current Replay Size 863
Episode  118 Current Replay Size 974
Episode  119 Current Replay Size 1040
Episode  120 Current Replay Size 1131
Episode  121 Current Replay Size 1239
Episode  122 Current Replay Size 1355
Episode  123 Current Replay Size 1485
Episode  124 Current Replay Size 1577
Episode  125 Current Replay Size 1690
Episode  126 Current Replay Size 1787
Episode  127 Current Replay Size 1873
Episode  128 Current Replay Size 2006
Episode  129 Current Replay Size 2150
Episode  130 Current Replay Size 2260
Episode  131 Current Replay Size 2358
Episode  132 Current Replay Size 2444


100%|██████████| 200/200 [00:06<00:00, 28.86it/s]


Average Loss 0.007786204806373764
Episode  133 Current Replay Size 2560


100%|██████████| 200/200 [00:07<00:00, 28.14it/s]


Average Loss 0.004704908001391836
Episode  134 Current Replay Size 2684


100%|██████████| 200/200 [00:08<00:00, 24.36it/s]


Average Loss 0.003810416118552286
Episode  135 Current Replay Size 2793


100%|██████████| 200/200 [00:08<00:00, 23.99it/s]


Average Loss 0.0033454056316986696
Episode  136 Current Replay Size 2860


100%|██████████| 200/200 [00:07<00:00, 27.01it/s]


Average Loss 0.0033483900909777715
Episode  137 Current Replay Size 2935


100%|██████████| 200/200 [00:07<00:00, 25.24it/s]


Average Loss 0.0031082432573505984
Episode  138 Current Replay Size 3107


100%|██████████| 200/200 [00:09<00:00, 22.11it/s]


Average Loss 0.0032021496250914994
Episode  139 Current Replay Size 3267


100%|██████████| 200/200 [00:07<00:00, 26.18it/s]


Average Loss 0.0027652905931851517
Episode  140 Current Replay Size 3367


100%|██████████| 200/200 [00:08<00:00, 23.81it/s]


Average Loss 0.0027689226930184907
Reward  4.0
Episode  141 Current Replay Size 3483


 18%|█▊        | 36/200 [00:01<00:06, 26.87it/s]

In [32]:
# Scratch: a single step() call with action 2 (RIGHT); `res` is
# [action, reward, stacked_frames, done]. A later cell rebinds `res`.
res = step(benv, 2)

In [12]:
# Index -> meaning of the discrete actions (1 = FIRE, pressed after reset).
benv.get_action_meanings()

['NOOP', 'FIRE', 'RIGHT', 'LEFT']

In [41]:
# NOTE(review): scratch cell. Building a DQN calls tf.reset_default_graph()
# and rebinds `dqn`, replacing the network trained above; `rewards` and
# `actions` are not used by later cells.
rewards = np.array([1,1,1,1])
actions = np.array([1,2,3,0])
dqn = DQN((84, 84, 2), 4)


[<tf.Variable 'conv2d/kernel:0' shape=(8, 8, 2, 16) dtype=float32>, <tf.Variable 'conv2d/bias:0' shape=(16,) dtype=float32>, <tf.Variable 'conv2d_1/kernel:0' shape=(4, 4, 16, 32) dtype=float32>, <tf.Variable 'conv2d_1/bias:0' shape=(32,) dtype=float32>, <tf.Variable 'dense/kernel:0' shape=(2592, 256) dtype=float32>, <tf.Variable 'dense/bias:0' shape=(256,) dtype=float32>, <tf.Variable 'dense_1/kernel:0' shape=(256, 4) dtype=float32>, <tf.Variable 'dense_1/bias:0' shape=(4,) dtype=float32>]


In [18]:
# NOTE(review): scratch call with an explicit checkpoint path; this only works
# if DQN.save accepts a path argument.
dqn.save('test3')

test3


In [20]:
# Scratch: removes the `dqn` name; any later cell using `dqn` fails on Run All.
del(dqn)

In [25]:
# NOTE(review): in file order this runs after `del(dqn)` (NameError on a fresh
# Run All), and it only works if DQN.load accepts a path argument.
dqn.load('test')

INFO:tensorflow:Restoring parameters from test


In [48]:
# NOTE(review): assumes `res` is the (batch, 84, 84, 2) array built in the
# Breakout-v0 cell further down; with `res` from step(), res[0] is the action.
x = np.array([res[0][:,:,0]])

In [84]:
# Scratch: raw observation returned by a FIRE step. NOTE(review): the
# AttributeError recorded below is stale output — this line calls no .astype.
benv.step(1)[0]

AttributeError: 'tuple' object has no attribute 'astype'

In [69]:
# Q-values for the scratch batch `res`; predict() asserts that res[0].shape
# equals dqn.state_shape.
p = dqn.predict(np.array(res))

In [75]:
# Scratch: scaled copy of `res`; nothing below uses `res2`.
res2 = res/255

In [70]:
# predict() returns a one-element list wrapping the (batch, actions) array.
p

[array([[0.        , 0.        , 0.        , 0.10603726],
        [0.        , 0.        , 0.        , 0.10603726]], dtype=float32)]

In [54]:
# NOTE(review): argmax over the flattened list-of-arrays; use
# p[0].argmax(axis=1) for per-sample actions. The recorded output (In [54])
# predates the `p` displayed above (In [70]).
np.argmax(p)

1

In [78]:
# Scratch: inspect raw observations from a Breakout-v0 env.
# NOTE(review): rebinding `benv` leaves the previous env unclosed, and the
# get_action_meanings() result is discarded because it is not the last line.
benv = gym.make('Breakout-v0')
benv.get_action_meanings()

state = np.array(benv.reset())
#state = cv2.resize(cv2.cvtColor(state, cv2.COLOR_BGR2GRAY), dsize = (84, 110))[101 - 84:-9]
print(state.shape)
# Two-frame stack of shape (84, 84, 2), duplicated into a batch of 2 for the
# 2-channel scratch DQN.
res = processImg([state, state])
res = np.array([res, res])

(210, 160, 3)


In [87]:
# Scratch: watch one episode of a uniformly random policy.
# reset() must come before render(): rendering an env that was never reset is
# invalid (newer gym versions raise an error for it).
state = benv.reset()
benv.render()
done = False
it = 0
benv.step(1)  # FIRE to launch the ball
while not done:
    action = np.random.randint(benv.action_space.n)
    state, reward, done, p = benv.step(action)
    it += 1
    time.sleep(0.01)
    benv.render()
    print(action, reward)
print(it)
benv.close()

1 0.0
2 0.0
2 0.0
1 0.0
2 0.0
2 0.0
2 0.0
0 0.0
1 0.0
1 0.0
0 0.0
3 0.0
0 0.0
1 0.0
0 0.0
1 0.0
2 0.0
1 0.0
1 0.0
0 0.0
0 0.0
2 0.0
3 0.0
3 0.0
2 0.0
0 0.0
1 0.0
3 0.0
2 0.0
3 0.0
0 0.0
1 0.0
1 0.0
2 0.0
0 0.0
0 0.0
3 0.0
0 0.0
3 0.0
3 0.0
2 0.0
0 0.0
0 0.0
0 0.0
1 0.0
1 0.0
2 0.0
1 0.0
2 0.0
0 0.0
2 0.0
0 0.0
0 0.0
3 0.0
0 0.0
0 0.0
2 0.0
3 0.0
0 0.0
1 0.0
0 0.0
0 0.0
3 0.0
1 0.0
0 0.0
1 0.0
3 0.0
3 0.0
0 0.0
0 0.0
3 0.0
3 0.0
2 0.0
3 0.0
3 0.0
3 0.0
2 0.0
1 0.0
3 0.0
1 0.0
1 0.0
2 0.0
3 0.0
3 0.0
3 0.0
0 0.0
1 0.0
2 0.0
2 0.0
0 0.0
0 0.0
0 0.0
1 0.0
2 0.0
0 0.0
1 0.0
3 0.0
1 0.0
0 0.0
0 0.0
0 0.0
3 0.0
0 0.0
0 0.0
1 0.0
0 0.0
1 0.0
3 0.0
3 0.0
2 0.0
0 0.0
3 0.0
1 0.0
2 0.0
0 0.0
0 0.0
3 0.0
2 0.0
1 0.0
3 0.0
0 0.0
1 0.0
2 0.0
0 0.0
0 0.0
3 0.0
2 0.0
3 0.0
1 0.0
2 0.0
2 0.0
1 0.0
1 0.0
3 0.0
2 0.0
2 0.0
0 0.0
3 0.0
1 0.0
2 0.0
1 0.0
1 0.0
2 0.0
2 0.0
0 0.0
1 0.0
1 0.0
1 0.0
3 0.0
3 0.0
2 0.0
2 0.0
2 0.0
0 0.0
3 0.0
2 0.0
2 0.0
2 0.0
1 0.0
0 0.0
0 0.0
2 0.0
3 0.0
3 0.0
0 0.0
0 1.0
0 0.

In [None]:
# Release `benv`.
benv.close()

In [None]:
# Unrelated to the DQN work above: a Fashion-MNIST Keras baseline.
data = keras.datasets.fashion_mnist

In [None]:
# Train/test images with integer class labels.
(train_img, train_lbl), (test_img, test_lbl) = data.load_data()

In [None]:
# Fashion-MNIST label abbreviations, indexed by class id 0-9:
# T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag,
# Ankle boot. Sandal (5) previously reused 'SN', colliding with Sneaker (7).
class_names = ['TS', "TR", "PL", 'DR', "CO", "SA", 'SH', 'SN', 'BG', 'AB']

In [None]:
# Preview the first training image.
plt.imshow(train_img[0], cmap = 'gray')

In [None]:
# Scale pixels to [0, 1]. Guarded on dtype so re-running this cell does not
# divide by 255 a second time (load_data returns uint8 images).
if train_img.dtype == np.uint8:
    train_img = train_img / 255
if test_img.dtype == np.uint8:
    test_img = test_img / 255

In [None]:
# Baseline MLP classifier: 28x28 image -> 128 ReLU units -> 10-way softmax.
model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape = (28, 28)))
model.add(keras.layers.Dense(128, activation = 'relu'))
model.add(keras.layers.Dense(10, activation = 'softmax'))

model.compile(
    optimizer = "adam",
    loss = "sparse_categorical_crossentropy",  # labels are integer class ids
    metrics = ["accuracy"],
)


In [None]:
# Sanity check: number of GPUs TensorFlow can see.
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

In [None]:
# Fit the baseline for 5 epochs with batches of 1000 images.
model.fit(train_img, train_lbl, batch_size = 1000, epochs=5, verbose = 1)