In [28]:
''' for data generation '''

import os
import random

import cv2
import gym
import gym.spaces
import numpy as np
import torch


''' model '''
import torch.nn as nn 
import torch.optim as optim 

''' summary writer '''
from tensorboardX import SummaryWriter

''' imagee '''
import torchvision.utils as vutils 

**Data source: Atari environments**
***
Changes we need to make:
     - Atari image resolution: 210 × 160 => 64 × 64
     - image layout: channel-last (H, W, C) => PyTorch's channel-first (C, H, W)
     - data type: the images are uint8, but the computation needs float32

***
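How to make these changes:

There are two ways to do it:

    First: change the gym env itself by wrapping it in an InputWrapper
        - create an InputWrapper class that inherits from gym.ObservationWrapper
        - replace the observation space with a new Box. Why? The wrapper changes the shape (3 × 64 × 64) and dtype (float32) of every observation, so `observation_space` must be updated to describe what `reset()`/`step()` now return; any code that reads `env.observation_space` (e.g. to size a network, or another wrapper) would otherwise still see the original 210 × 160 × 3 uint8 space.

    Second: take the raw observation from gym and post-process it ourselves
        - get obs using env.step
        - define a function that transforms it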
        
    

In [5]:
# ---- constants we need ----
Image_size = 64   # side length (pixels) of the square frames fed to the GAN
batch_size = 16   # number of real frames per batch drawn from the environments

# ---- optional dumping of preprocessed env frames to disk ----
save = False      # master switch: write frames only when True
saved_index = 0   # running counter used to name the saved frames
max_save = 100    # no more frames are written once the counter reaches this

## Method 1: `InputWrapper`

In [21]:
class InputWrapper(gym.ObservationWrapper):
    """Observation wrapper that converts raw Atari frames into GAN inputs.

    Every observation is resized to ``Image_size`` x ``Image_size``, moved from
    channel-last (H, W, C) to channel-first (C, H, W) layout and cast to
    float32. Pixel values keep their 0..255 range; ``iterate_batches`` scales
    them to [-1, 1]. When the module-level ``save`` flag is set, non-blank
    resized frames are also written to disk by ``save_images``.
    """

    def __init__(self, *args):
        super(InputWrapper, self).__init__(*args)
        assert isinstance(self.observation_space, gym.spaces.Box)
        old_space = self.observation_space
        # The wrapper changes the shape and dtype of every observation, so the
        # advertised space must be rebuilt to describe what reset()/step() now
        # return. The bounds go through the pure transform rather than
        # observation(): with save=True, observation(old_space.high) would dump
        # an all-white "frame" to disk while constructing the wrapper.
        self.observation_space = gym.spaces.Box(
            self._to_chw_float(self._resize(old_space.low)),
            self._to_chw_float(self._resize(old_space.high)),
            dtype=np.float32,
        )

    @staticmethod
    def _resize(frame):
        """Resize a channel-last frame to Image_size x Image_size."""
        return cv2.resize(frame, (Image_size, Image_size))

    @staticmethod
    def _to_chw_float(frame):
        """Move the channel axis first (HWC -> CHW) and cast to float32."""
        return np.moveaxis(frame, 2, 0).astype(np.float32)

    def observation(self, observation):
        """Transform one raw frame; if ``save`` is on, dump non-blank frames first."""
        new_obs = self._resize(observation)
        # mean > 0.01 skips (almost) completely black frames
        if save and np.mean(new_obs) > 0.01:
            self.save_images(new_obs)
        return self._to_chw_float(new_obs)

    def save_images(self, obs):
        """Write ``obs`` (channel-last, 0..255) as a PNG while ``saved_index < max_save``.

        ``saved_index`` is advanced on every call, whether or not a file is
        written.
        """
        global saved_index
        if saved_index < max_save:
            path = './atari saved images/wrapper_method/img' + str(saved_index) + '.png'
            # cv2.imwrite returns False instead of raising when the folder is missing
            os.makedirs(os.path.dirname(path), exist_ok=True)
            # cv2.imwrite expects BGR channel order; gym's Atari frames are RGB
            cv2.imwrite(path, cv2.cvtColor(np.uint8(obs), cv2.COLOR_RGB2BGR))
        saved_index += 1

In [23]:
def iterate_batches(envs, size=None):
    """Yield an endless stream of batches of real frames from random environments.

    Parameters
    ----------
    envs : list
        Gym-style environments using the 4-tuple ``step`` API. Their
        observations are expected to be float arrays in 0..255 (as produced by
        ``InputWrapper``).
    size : int, optional
        Number of frames per batch. Defaults to the module-level ``batch_size``.

    Yields
    ------
    torch.Tensor
        float32 tensor of shape ``(size, *observation.shape)`` scaled to [-1, 1].

    Raises
    ------
    ValueError
        On the first ``next()`` if ``envs`` is empty or ``size`` is not positive.
    """
    if size is None:
        size = batch_size
    if not envs:
        raise ValueError("iterate_batches needs at least one environment")
    if size <= 0:
        raise ValueError("batch size must be positive")

    for env in envs:
        env.reset()

    batch = []
    while True:
        # each step is taken in a randomly chosen environment
        env = random.choice(envs)
        obs, reward, is_done, _ = env.step(env.action_space.sample())

        # skip (almost) black frames: one of the games flickers and emits them
        if np.mean(obs) > 0.01:
            batch.append(obs)

        if len(batch) == size:
            # 0..255 -> -1..1, matching the generator's tanh output range
            batch_np = np.asarray(batch, np.float32) * 2 / 255.0 - 1
            yield torch.tensor(batch_np)
            batch.clear()

        if is_done:
            env.reset()
            
            
  

      
# env_names = ['Breakout-v0', 'AirRaid-v0', 'Pong-v0']

# envs = [InputWrapper(gym.make(name)) for name in env_names] 
# for e in envs:
#     print(e.observation_space.shape)
    
# x_max = 1
# x = 0 

# for batch_v in iterate_batches(envs):
#     if x < x_max:
#         x+= 1 
#         print(batch_v.size())
#         continue 
#     else:
#         break 





## Method 2: apply these operations outside the environment

In [2]:
# ''' constants we need '''

# Image_size = 64 # output image size for our GAN 
# batch_size = 16  # batch size to generate from env 

# # saving env images to disk 
# save = True 
# saved_index = 0 
# max_save = 100 


# def save_image(obs):
#     global saved_index 
#     if saved_index < max_save:
#         cv2.imwrite(
#             './atari saved images/non_wrapper_method/img' + str(saved_index) + '.png',
#             np.uint8(obs))
#         saved_index += 1


# def preprocess(obs):
#     obs = cv2.resize(obs, (Image_size, Image_size))
#     if save and saved_index < max_save:
#         save_image(obs) 
        
#     obs = np.moveaxis(a=obs, source=2, destination=0)
#     obs = obs.astype(np.float32)
#     return obs


# def iterate_batches(envs):
#     global saved_index, save, batch_size
    
#     [e.reset() for e in envs] 
    
#     batch = []
#     env_gen = iter(lambda: random.choice(envs), None)

#     while True:
#         e = next(env_gen)
#         obs, reward, is_done, _ = e.step(e.action_space.sample())
        
#         # check for non-zero mean of image, due to bug in one of the games to prevent flickering of images
        
#         if np.mean(obs) > 0.01:
#             obs = preprocess(obs)
#             batch.append(obs)
        
#         if len(batch) == batch_size:
#             batch_np = np.asarray(batch, np.float32) * 2 / 255.0 - 1 # domain to -1 to 1 
#             yield torch.tensor(batch_np)
#             batch.clear()

#         if is_done:
#             e.reset() 

In [3]:
# env_names = ['Breakout-v0', 'AirRaid-v0', 'Pong-v0']

# envs = [gym.make(name) for name in env_names] 

# x_max = 2 
# x = 0 

# for batch_v in iterate_batches(envs):
#     if x < x_max:
#         x+= 1 
#         print(batch_v.size())
#         continue 
#     else:
#         break 
    

# model 

In [9]:
# ---- Discriminator constants ----
DISC_FILTERS = 64    # filters in the first conv layer; doubled at each downsampling stage
input_channels = 3   # channels of the input frames (colour images)

In [10]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring images as real (close to 1) or fake (close to 0).

    Four 4x4 stride-2 convolutions shrink a 64x64 input to 4x4 while the
    filter count grows DISC_FILTERS, x2, x4, x8. A final 4x4 unpadded
    convolution collapses the map to one value per image, and a sigmoid
    squashes it. The layer order and indices inside ``conv_pipe`` are fixed,
    so saved state_dicts keep loading.
    """

    def __init__(self, input_channels):
        super(Discriminator, self).__init__()

        layers = []
        channels_in = input_channels
        # 64 -> 32 -> 16 -> 8 -> 4
        for stage in range(4):
            channels_out = DISC_FILTERS * 2 ** stage
            layers.append(nn.Conv2d(in_channels=channels_in, out_channels=channels_out,
                                    kernel_size=4, stride=2, padding=1))
            # the first stage has no batch norm
            if stage > 0:
                layers.append(nn.BatchNorm2d(channels_out))
            layers.append(nn.ReLU())
            channels_in = channels_out

        # 4 -> 1, then squash to a probability
        layers.append(nn.Conv2d(in_channels=channels_in, out_channels=1,
                                kernel_size=4, stride=1, padding=0))
        layers.append(nn.Sigmoid())

        self.conv_pipe = nn.Sequential(*layers)

    def forward(self, x):
        # flatten the conv output to a 1-D tensor of scores
        return self.conv_pipe(x).view(-1)
        

In [None]:
# '''test your discriminator '''
# disc = Discriminator(input_channels)
# test_output = disc(batch_v)
# print(test_output) 

In [11]:
# ---- Generator constants ----
out_channels = 3           # channels of the generated images
generator_filters = 64     # base filter count; the first layer uses 8x this
latent_vector_size = 100   # length of the noise vector fed to the generator

In [12]:
class Generator(nn.Module):
    """DCGAN-style generator: (N, latent_vector_size, 1, 1) noise -> (N, out_channels, 64, 64) images.

    A stride-1 transposed conv expands the 1x1 input to 4x4. Four stride-2
    transposed convs then double the resolution up to 64x64 while halving the
    filter count. A final tanh keeps pixels in [-1, 1]. Layer indices inside
    ``deconvpipe`` are fixed, so saved state_dicts keep loading.
    """

    def __init__(self, out_channels):
        super(Generator, self).__init__()

        # 1x1 -> 4x4
        layers = [
            nn.ConvTranspose2d(in_channels=latent_vector_size, out_channels=generator_filters * 8,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(generator_filters * 8),
            nn.ReLU(),
        ]

        # 4 -> 8 -> 16 -> 32, halving the filters at every stage
        for multiplier in (8, 4, 2):
            wide = generator_filters * multiplier
            narrow = wide // 2
            layers += [
                nn.ConvTranspose2d(in_channels=wide, out_channels=narrow,
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(narrow),
                nn.ReLU(),
            ]

        # 32 -> 64, down to the image channels
        layers += [
            nn.ConvTranspose2d(in_channels=generator_filters, out_channels=out_channels,
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]

        self.deconvpipe = nn.Sequential(*layers)

    def forward(self, x):
        return self.deconvpipe(x)
        

In [13]:
# gen = Generator(out_channels)
# test_in = torch.FloatTensor(1, latent_vector_size, 1, 1).normal_(0,1) 
# test_out = gen(test_in)
# print(gen)

# print(test_out.shape)

In [16]:
''' main script '''
# Seed the RNGs this notebook draws from, so that network weight init and the
# environment choice (random.choice in iterate_batches) are reproducible.
# NOTE(review): gym action spaces may keep their own RNG (action_space.seed());
# verify for the installed gym version if fully repeatable runs are needed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

device = "cuda" if torch.cuda.is_available() else "cpu"
print("used device: ", device)

gen = Generator(out_channels).to(device)
disc = Discriminator(input_channels).to(device)

print(gen)
print(disc)

used device:  cuda
Generator(
  (deconvpipe): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Tanh()
  )
)
Discriminator(
  (conv_pipe): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4)

In [24]:
# three Atari games provide the "real" images for the GAN
env_names = ['Breakout-v0', 'AirRaid-v0', 'Pong-v0']
envs = [InputWrapper(gym.make(game)) for game in env_names]

print("input shape: ", envs[0].observation_space.shape)

# binary cross-entropy on the discriminator's sigmoid outputs
objective = nn.BCELoss()

# one Adam optimizer per network, beta1 = 0.5
gopt = optim.Adam(params=gen.parameters(), lr=0.0001, betas=(0.5, 0.999))
dopt = optim.Adam(params=disc.parameters(), lr=0.0001, betas=(0.5, 0.999))




input shape:  (3, 64, 64)


In [25]:
# show gym's INFO-level messages (used for the training progress log)
log = gym.logger
log.set_level(log.INFO)

In [70]:
''' train script '''
writer = SummaryWriter()

train_iter = 0
max_iter = 20000               # number of training iterations to run

report_every = 100             # average and log the losses every N iterations
save_image_every_iter = 1000   # write real/fake image grids to TensorBoard every N iterations


true_labels = torch.ones(batch_size, dtype=torch.float32, device=device)
fake_labels = torch.zeros(batch_size, dtype=torch.float32, device=device)

disc_losses = []
gen_losses = []

# generate_images() puts gen into eval mode; without this, re-running the
# training cell afterwards would train with BatchNorm using running stats.
gen.train()
disc.train()

try:
    for batch_v in iterate_batches(envs):

        ######################## train discriminator ############################################
        dopt.zero_grad()

        ## prepare the inputs: latent noise for the generator and a real batch
        gen_input = torch.randn(batch_size, latent_vector_size, 1, 1, device=device)
        batch_v = batch_v.to(device)

        ## forward the models
        gen_output = gen(gen_input)
        disc_output_on_real = disc(batch_v)
        # detach: this step updates only the discriminator
        disc_output_on_fake = disc(gen_output.detach())

        ## loss: real frames should score 1, generated frames 0
        disc_loss = objective(disc_output_on_real, true_labels) + objective(disc_output_on_fake, fake_labels)
        disc_losses.append(disc_loss.item())

        disc_loss.backward()
        dopt.step()

        ######################## train generator #################################################
        gopt.zero_grad()

        # re-score the (non-detached) fakes with the freshly updated discriminator
        disc_output_g = disc(gen_output)

        # the generator wants its fakes classified as real
        gen_loss = objective(disc_output_g, true_labels)
        gen_losses.append(gen_loss.item())

        # this also accumulates grads in disc's parameters; dopt.zero_grad()
        # clears them at the start of the next iteration
        gen_loss.backward()
        gopt.step()

        ################## summary writer ##########################################################
        train_iter += 1

        if train_iter % report_every == 0:
            log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e", train_iter, np.mean(gen_losses), np.mean(disc_losses))
            writer.add_scalar("gen_loss", np.mean(gen_losses), train_iter)
            writer.add_scalar("disc_loss", np.mean(disc_losses), train_iter)
            gen_losses.clear()
            disc_losses.clear()

        if train_iter % save_image_every_iter == 0:
            writer.add_image("fake", vutils.make_grid(gen_output.detach()[:64], normalize=True), train_iter)
            writer.add_image("real", vutils.make_grid(batch_v.detach()[:64], normalize=True), train_iter)

        # stop after exactly max_iter iterations (a `>` check would run one extra)
        if train_iter >= max_iter:
            break
finally:
    # flush and close the event file even if training is interrupted
    writer.close()

INFO: Iter 100: gen_loss=5.884e+00, dis_loss=4.332e-02
INFO: Iter 200: gen_loss=6.145e+00, dis_loss=1.291e-01
INFO: Iter 300: gen_loss=6.171e+00, dis_loss=6.188e-02
INFO: Iter 400: gen_loss=8.798e+00, dis_loss=1.665e-01
INFO: Iter 500: gen_loss=7.023e+00, dis_loss=1.476e-01
INFO: Iter 600: gen_loss=7.588e+00, dis_loss=9.536e-02
INFO: Iter 700: gen_loss=8.383e+00, dis_loss=7.884e-02
INFO: Iter 800: gen_loss=6.294e+00, dis_loss=1.509e-01
INFO: Iter 900: gen_loss=7.889e+00, dis_loss=9.599e-02
INFO: Iter 1000: gen_loss=8.300e+00, dis_loss=2.501e-02
INFO: Iter 1100: gen_loss=8.669e+00, dis_loss=1.303e-01
INFO: Iter 1200: gen_loss=7.521e+00, dis_loss=1.525e-02
INFO: Iter 1300: gen_loss=7.671e+00, dis_loss=2.952e-02
INFO: Iter 1400: gen_loss=9.336e+00, dis_loss=7.092e-02
INFO: Iter 1500: gen_loss=8.557e+00, dis_loss=2.447e-01
INFO: Iter 1600: gen_loss=8.790e+00, dis_loss=2.840e-02
INFO: Iter 1700: gen_loss=8.849e+00, dis_loss=3.918e-02
INFO: Iter 1800: gen_loss=6.843e+00, dis_loss=5.281e-02
I

INFO: Iter 14700: gen_loss=7.546e+00, dis_loss=2.667e-02
INFO: Iter 14800: gen_loss=8.177e+00, dis_loss=3.690e-02
INFO: Iter 14900: gen_loss=8.525e+00, dis_loss=9.347e-02
INFO: Iter 15000: gen_loss=9.105e+00, dis_loss=1.056e-01
INFO: Iter 15100: gen_loss=7.645e+00, dis_loss=1.003e-01
INFO: Iter 15200: gen_loss=7.470e+00, dis_loss=6.760e-02
INFO: Iter 15300: gen_loss=6.250e+00, dis_loss=2.679e-02
INFO: Iter 15400: gen_loss=8.387e+00, dis_loss=6.504e-02
INFO: Iter 15500: gen_loss=9.920e+00, dis_loss=8.563e-02
INFO: Iter 15600: gen_loss=7.739e+00, dis_loss=1.464e-01
INFO: Iter 15700: gen_loss=8.058e+00, dis_loss=1.014e-01
INFO: Iter 15800: gen_loss=8.335e+00, dis_loss=1.140e-01
INFO: Iter 15900: gen_loss=8.492e+00, dis_loss=6.459e-02
INFO: Iter 16000: gen_loss=7.400e+00, dis_loss=2.310e-02
INFO: Iter 16100: gen_loss=7.492e+00, dis_loss=3.777e-02
INFO: Iter 16200: gen_loss=9.275e+00, dis_loss=2.565e-02
INFO: Iter 16300: gen_loss=8.224e+00, dis_loss=1.788e-02
INFO: Iter 16400: gen_loss=8.79

In [73]:

def generate_images(n):
    """Sample ``n`` images from the global generator ``gen``.

    Returns a uint8 array of shape (n, H, W, C) in 0..255, with channels in
    the order the generator produces them. The generator's train/eval mode is
    restored afterwards, so calling this between training runs does not leave
    BatchNorm in inference mode.
    """
    was_training = gen.training
    gen.eval()
    try:
        # inference only: no autograd graph needed
        with torch.no_grad():
            latent = torch.randn(n, latent_vector_size, 1, 1, device=device)
            images = gen(latent)
    finally:
        gen.train(was_training)

    # tanh output in [-1, 1] -> pixel range [0, 255]
    images = (images + 1) * 255.0 / 2

    images = images.cpu().numpy()
    images = np.moveaxis(images, 1, 3)  # NCHW -> NHWC
    print("shape of data: ", images.shape, " type ", type(images))

    # round instead of truncating; clip guards against float overshoot
    return np.clip(images, 0, 255).round().astype(np.uint8)

In [76]:
out_dir = './atari saved images/GAN_generated_images'
# cv2.imwrite returns False instead of raising when the folder is missing
os.makedirs(out_dir, exist_ok=True)

images = generate_images(100)
for i, image in enumerate(images):
    # the generator learned from RGB frames, but cv2.imwrite expects BGR
    cv2.imwrite(out_dir + '/img' + str(i) + ".png", cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

shape of data:  (100, 64, 64, 3)  type  <class 'numpy.ndarray'>


In [71]:
# torch.save does not create missing parent directories
os.makedirs('./saved_models', exist_ok=True)
torch.save(gen.state_dict(), './saved_models/generator')

In [72]:
# torch.save does not create missing parent directories
os.makedirs('./saved_models', exist_ok=True)
torch.save(disc.state_dict(), './saved_models/discriminator')