Install Necessary Packages

In [None]:
# In Google Colab
!pip install diffusers transformers accelerate

Import Necessary Libraries

In [None]:
from diffusers import StableDiffusionXLPipeline
import torch

Image Generator: Unlimited Image Generations + Pretty Fast

In [None]:
# Hugging Face `diffusers` provides a pipeline for loading and running the image
# generator. Here we use Segmind Stable Diffusion 1B (SSD-1B), which is free and
# fast, though not as good in quality as today's state-of-the-art models.
# The weights are loaded in float16 and moved to CUDA, so fail early with a clear
# message (before downloading several GB of weights) if no GPU is attached.
if not torch.cuda.is_available():
    raise RuntimeError(
        "SSD-1B is loaded in float16 and moved to CUDA; enable a GPU runtime "
        "(Colab: Runtime > Change runtime type > GPU) before running this cell."
    )
pipe = StableDiffusionXLPipeline.from_pretrained(
    "segmind/SSD-1B",
    torch_dtype=torch.float16
)
pipe.to("cuda")

### Approach 1: Basic Q-Learning

Custom Environment for Human-In-The-Loop Image Generation

In [None]:
import numpy as np
from tqdm import tqdm
import random
from collections import defaultdict
import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
from re import M
from IPython.display import display
# Follows the OpenAI Gym custom-environment interface (reset/step) without depending on the gym package itself

class ArtCreationEnv():
  """Human-in-the-loop environment for tuning an image-generation prompt.

  Loosely follows the Gym reset()/step() interface. The state is a 4-tuple of
  option indices (color, style, mood, composition) into ``obs_space``. Each
  step builds a prompt from the state, renders it with ``image_generator`` and
  asks the user for a 1-5 rating, which becomes the reward in [-1, 1].

  Args:
    image_generator: pipeline called as
      ``image_generator(prompt, num_inference_steps=20).images[0]``.
    obs_space: dict with keys 'color', 'style', 'mood' and 'composition', each
      mapping to a list of option strings.
    model: 'q-learn', 'dqn' or 'policy-gradient'; selects how actions are
      interpreted and what reset()/step() return.
  """

  # Fixed length of the feature vector produced by state_dqn().
  DQN_STATE_DIM = 128

  def __init__(self, image_generator, obs_space, model = 'q-learn'):
    self.image_generator = image_generator
    # curr_state is a tuple of option indices, one per obs_space dimension
    # (color, style, mood, composition).
    self.curr_state = (0,) * len(obs_space)
    self.obs_space = obs_space
    self.model = model
    self.step_count = 0
    # Episodes are truncated after this many steps.
    self.max_iters = 10
    if model == 'dqn' or model == 'policy-gradient':
      self.rating_history = []
      self.action_history = []

  def reset(self):
    """Reset to the all-zeros state and clear the step counter.

    Returns:
      'q-learn': the state tuple only.
      'dqn': (state_dqn() feature vector, info dict).
      'policy-gradient': (state tuple, info dict).
    """
    self.curr_state = (0,) * len(self.obs_space)
    self.step_count = 0
    info = {}
    if self.model == 'q-learn':
      return self.curr_state

    if self.model == 'dqn' or self.model == 'policy-gradient':
      self.rating_history = []
      self.action_history = []

      if self.model == 'policy-gradient':
        state = self.curr_state
      else:
        state = self.state_dqn()

      return state, info

  def state_dqn(self):
    """Build the fixed-length float32 feature vector used as the DQN input.

    Layout before zero-padding to DQN_STATE_DIM:
      [0:4]   current option index per dimension divided by the number of
              options in that dimension
      [4:8]   mean of the last 5 rewards, max reward, min reward, last reward
              (all zero before the first step)
      [8:28]  relative frequency of options 0-4 of color, style, mood and
              composition over the last 10 actions (options >= 5 are counted
              but not included)

    Returns:
      np.float32 array of length DQN_STATE_DIM.
    """
    dims = ['color', 'style', 'mood', 'composition']
    features = [self.curr_state[i] / len(self.obs_space[dims[i]])
                for i in range(len(self.curr_state))]

    if len(self.rating_history) > 0:
      features.append(np.mean(self.rating_history[-5:]))
      features.append(np.max(self.rating_history))
      features.append(np.min(self.rating_history))
      features.append(self.rating_history[-1])
    else:
      features.extend([0, 0, 0, 0])

    recent = self.action_history[-10:]
    counts = {dim: np.zeros(len(self.obs_space[dim])) for dim in dims}
    for action in recent:
      for dim in dims:
        counts[dim][action[dim]] += 1
    if recent:
      for dim in dims:
        counts[dim] = counts[dim] / len(recent)
    for dim in dims:
      features.extend(counts[dim][:5])

    features = np.array(features, dtype=np.float32)
    n = self.DQN_STATE_DIM
    if len(features) < n:
      # Pad only at the END. np.pad(x, k) with a scalar k pads k on BOTH sides,
      # which would give 2n - len(x) entries and shift the real features.
      features = np.pad(features, (0, n - len(features)))
    else:
      features = features[:n]
    return features

  def apply_action(self, action):
    """Render the text prompt for a state.

    Args:
      action: indexable of 4 option indices in the order
        (color, style, mood, composition).

    Returns:
      The prompt string.
    """
    color = self.obs_space['color'][action[0]]
    style = self.obs_space['style'][action[1]]
    mood = self.obs_space['mood'][action[2]]
    composition = self.obs_space['composition'][action[3]]

    new_prompt = f"a beautiful landscape, {color} color palette, {style} style, {mood} mood, {composition} composition"
    return new_prompt

  def step(self, action, agent):
    """Advance one step: build the prompt, generate an image, collect a rating.

    Args:
      action: depends on ``model``:
        'q-learn': a (dimension, index) pair applied via agent.apply_action;
        'dqn': a dict mapping each dimension name to an option index;
        'policy-gradient': the full next state tuple.
      agent: the Q-learning agent (only used when model == 'q-learn').

    Returns:
      'q-learn': (next_state, reward, info)
      'policy-gradient': (next_state, reward, info, terminated, truncated)
      'dqn': (state_dqn() features, reward, info, terminated, truncated)
      where info = {'image': ..., 'prompt': ...}.
    """
    self.step_count += 1

    if self.model == 'policy-gradient':  # the PG agent returns the next state directly
      next_state = action

    elif self.model == 'dqn':
      self.curr_state = (
        action['color'],
        action['style'],
        action['mood'],
        action['composition']
        )
      next_state = self.curr_state
    else:  # q-learn: the action edits one dimension of the current state
      next_state = agent.apply_action(self.curr_state, action)

    prompt = self.apply_action(next_state)
    image = self.generate_image(prompt)
    reward = self.get_rating(image, prompt)
    if self.model == 'dqn':
      self.rating_history.append(reward)
      self.action_history.append(action)

    terminated, truncated = False, False

    if self.step_count >= self.max_iters:
      truncated = True

    # DQN: terminate once the last 3 rewards are all > .5, i.e. three 5/5
    # ratings in a row. Three ratings are enough to evaluate the rule.
    if self.model == 'dqn' and len(self.rating_history) >= 3:
      if all(r > .5 for r in self.rating_history[-3:]):
        terminated = True

    self.curr_state = next_state
    info = {'image': image, 'prompt': prompt}

    if self.model == 'q-learn':
      return next_state, reward, info
    if self.model == 'policy-gradient':
      return next_state, reward, info, terminated, truncated
    state = self.state_dqn()
    return state, reward, info, terminated, truncated

  def generate_image(self, prompt):
    """Render ``prompt`` with the image generator and return the first image.

    The inference-step count is kept low for speed during training.
    """
    image = self.image_generator(
        prompt,
        num_inference_steps=20
    ).images[0]
    return image

  def get_rating(self, image, prompt):
    """Show the image, ask the user for a 1-5 rating and map it to a reward.

    Invalid input (non-integers or values outside 1-5) is rejected and the
    user is asked again.

    Returns:
      (rating - 3) / 2, i.e. 5 -> 1.0, 4 -> .5, 3 -> 0, 2 -> -.5, 1 -> -1.0.
    """
    display(image)
    print(f"Prompt: {prompt}")
    print("Rate this image:")
    while True:
      raw = input("Rate 1-5:")
      try:
        rating = int(raw)
      except ValueError:
        rating = None
      if rating is not None and 1 <= rating <= 5:
        break
      print("Please enter a whole number from 1 to 5.")
    return (rating - 3)/2


In [None]:
class HumanInLoopAgent():
  """Tabular Q-learning agent that edits one prompt dimension per step.

  A state is a 4-tuple of option indices (color, style, mood, composition).
  An action is a pair such as ('color', 3) that overwrites one entry of the
  state; with 4 dimensions x ``dim_space`` (5) values there are 20 actions.

  Args:
    learning_rate: Q-learning step size (alpha).
    initial_epsilon: starting exploration probability.
    epsilon_decay: multiplicative factor applied by decay_epsilon().
    final_epsilon: lower bound for epsilon.
    discount_factor: gamma in the Q-learning target.
    q_table: optional pre-existing table (e.g. from load_style); a falsy value
      starts from an empty table.
  """
  def __init__(self, learning_rate, initial_epsilon, epsilon_decay, final_epsilon, discount_factor, q_table = False):
    self.lr = learning_rate
    self.epsilon = initial_epsilon
    self.epsilon_decay = epsilon_decay
    self.final_epsilon = final_epsilon
    self.discount_factor = discount_factor
    # Q[state][action] -> value; unseen entries default to 0.0.
    self.q_table = q_table if q_table else defaultdict(lambda: defaultdict(float))
    self.dim_space = 5
    self.dims = ['color', 'style', 'mood', 'composition']

  def action_space(self):
    """All (dimension, value) pairs, dimension-major."""
    return [(dim, val) for dim in self.dims for val in range(self.dim_space)]

  def choose_action(self, state):
    """Epsilon-greedy choice; greedy ties are broken uniformly at random."""
    if np.random.random() < self.epsilon:
      dim = random.choice(self.dims)
      val = random.randint(0, self.dim_space - 1)
      return (dim, val)
    actions = self.action_space()
    q_values = [self.q_table[state][a] for a in actions]
    max_q = max(q_values)
    best_actions = [a for a, q in zip(actions, q_values) if q == max_q]
    return random.choice(best_actions)

  def decay_epsilon(self):
    """Multiply epsilon by epsilon_decay, never going below final_epsilon."""
    self.epsilon = max(self.final_epsilon, self.epsilon * self.epsilon_decay)

  def update(self, state, action, reward, next_state):
    '''
    Q(s,a) ← Q(s,a) + α[r + γ max Q(s',a') - Q(s,a)]
    '''
    current_q = self.q_table[state][action]
    q_values = [self.q_table[next_state][a] for a in self.action_space()]
    max_next_q = max(q_values) if q_values else 0
    new_q = current_q + self.lr * (reward + self.discount_factor * max_next_q - current_q)
    self.q_table[state][action] = new_q

  def apply_action(self, curr_state, best_action):
    """Return a copy of curr_state with one dimension replaced.

    Example: apply_action((2, 3, 3, 3), ('color', 3)) -> (3, 3, 3, 3).
    """
    dim, val = best_action
    new_state = list(curr_state)
    new_state[self.dims.index(dim)] = val
    return tuple(new_state)

  def get_best_action(self, curr_state):
    """Return the STATE reached by the greedy action from curr_state.

    Ties resolve to the first action in action_space() order.
    """
    actions = self.action_space()
    q_values = [self.q_table[curr_state][a] for a in actions]
    best_action = actions[np.argmax(q_values)]
    return self.apply_action(curr_state, best_action)

  def save_style(self, filename):
    """Pickle the Q-table and hyperparameters to ``filename``."""
    with open(filename, 'wb') as file:
      data = {
          # The outer defaultdict's lambda factory is not picklable -> plain dict.
          'q_table': dict(self.q_table),
          'lr': self.lr,
          'epsilon': self.epsilon,
          'epsilon_decay': self.epsilon_decay,
          'final_epsilon': self.final_epsilon,
          'discount_factor': self.discount_factor
      }
      pickle.dump(data, file)

  @classmethod
  def load_style(cls, filename):
    """Load an agent previously written by save_style.

    Callable on the class or on an instance. Returns None if ``filename``
    does not exist.

    SECURITY: pickle.load can execute arbitrary code; only load files you
    created yourself.
    """
    if not os.path.exists(filename):
      return None
    with open(filename, 'rb') as file:
      data = pickle.load(file)
    q_table = defaultdict(lambda: defaultdict(float), data['q_table'])
    return cls(data['lr'], data['epsilon'], data['epsilon_decay'],
               data['final_epsilon'], data['discount_factor'], q_table)

  def generate_my_image(self, prompt, env):
    """Generate, display and return an image using the learned settings.

    Args:
      prompt: optional explicit prompt. If falsy, the prompt is built with
        ``env.apply_action`` from the best state reachable from
        ``env.curr_state``.
      env: ArtCreationEnv used to build the prompt and render the image.

    Returns:
      The generated image.
    """
    if prompt:
      final_prompt = prompt
    else:
      best_state = self.get_best_action(env.curr_state)
      print(f"Best discovered settings: {best_state}")
      final_prompt = env.apply_action(best_state)
    print(f"\nGenerating image with: {final_prompt}")
    final_image = env.generate_image(final_prompt)
    display(final_image)
    return final_image





Train the Agent

In [None]:
# --- Q-learning hyperparameters ---
learning_rate = 0.2
start_epsilon = .8
epsilon_decay = .88
final_epsilon = 0.1
discount_factor = .95
n_iters = 20

# Small search space: 5 options per prompt dimension.
env = ArtCreationEnv(pipe, {
    'color': ['warm', 'cool', 'vibrant', 'muted', 'monochrome'],
    'style': ['realistic', 'painterly', 'abstract', 'minimalist', 'surreal'],
    'mood': ['peaceful', 'energetic', 'mysterious', 'joyful', 'melancholic'],
    'composition': ['centered', 'rule-of-thirds', 'asymmetric', 'symmetrical', 'dynamic']
})

agent = HumanInLoopAgent(learning_rate, start_epsilon, epsilon_decay, final_epsilon, discount_factor)

# One continuous episode: each iteration asks the user for one rating.
for i in tqdm(range(n_iters)):
  if i == 0:
    state = env.reset()
  else:
    state = env.curr_state
  action = agent.choose_action(state)
  next_state, reward, info = env.step(action, agent)

  agent.update(state, action, reward, next_state)
  agent.decay_epsilon()

  print(f"Reward: {reward:.2f}")
  print(f"Current best settings: {agent.get_best_action(next_state)}")

# Take the greedy move from where training ended and render it.
best_state = agent.get_best_action(env.curr_state)
print(f"Best discovered settings: {best_state}")

final_prompt = env.apply_action(best_state)
print(f"\nGenerating final image with: {final_prompt}")

final_image = env.generate_image(final_prompt)
display(final_image)


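Save the Model, Load It, and Test It! (The cell below tests the learned agent; use `save_style` / `load_style` to persist and reload it.)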

In [None]:
# The agent is trained: take the best single-dimension move from the current
# state and render it.
best_state = agent.get_best_action(env.curr_state)
print(f"Best discovered settings: {best_state}")

prompt = env.apply_action(best_state)
print(f"\nGenerating image with: {prompt}")

final_image = env.generate_image(prompt)
display(final_image)

Our Q-table is way too big! 5 × 5 × 5 × 5 = 625 states, and each state has 20 possible actions, giving 12,500 entries. Let's pivot to Deep Q-Learning!

Next, we expand the artist's choices (the options they give feedback on). Because we are now working in a much higher-dimensional space, we will afterwards also try a different approach: policy gradient.

### Let's expand the state space. For this, we'll implement Deep Q-Learning


Sources used: https://docs.pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

In [None]:
from collections import namedtuple, deque
import math
# Prefer CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
from itertools import count

In [None]:
##Replay memory -> decorrelates experiences by randomly sampling which improves generalization

# One stored experience; fields are accessible by name
# (t.state, t.action, t.next_state, t.reward).
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Bounded FIFO buffer of Transitions with uniform random sampling.

    Sampling random minibatches decorrelates consecutive experiences. Once
    ``capacity`` transitions are stored, appending drops the oldest one.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition built from (state, action, next_state, reward)."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        size = len(self.memory)
        return size

In [None]:
class DQN(nn.Module):
    """Branching Q-network: a shared trunk plus one Q-value head per prompt dimension.

    Each head (color, style, mood, composition) outputs one Q-value per option
    of that dimension, so the agent picks one option per category.

    Args:
        n_observations: length of the input feature vector.
        head_sizes: optional dict overriding the number of Q-values per head.
            Defaults match the expanded DQN ``obs_space`` used in this notebook
            (22 colors, 18 styles, 16 moods, 16 compositions). Attribute names
            are unchanged, so state_dict keys stay the same; pass
            ``head_sizes={'composition': 15}`` to load weights saved with the
            old 15-way composition head.
    """

    DEFAULT_HEAD_SIZES = {'color': 22, 'style': 18, 'mood': 16, 'composition': 16}

    def __init__(self, n_observations, head_sizes=None):
        super().__init__()
        sizes = dict(self.DEFAULT_HEAD_SIZES)
        if head_sizes:
            sizes.update(head_sizes)
        self.head_sizes = sizes

        self.start = nn.Sequential(
            nn.Linear(n_observations, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )

        self.color = self._make_head(sizes['color'])
        self.style = self._make_head(sizes['style'])
        self.mood = self._make_head(sizes['mood'])
        self.composition = self._make_head(sizes['composition'])

    @staticmethod
    def _make_head(n_out):
        """Small MLP mapping the 128-d trunk output to ``n_out`` Q-values."""
        return nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_out)
        )

    def forward(self, x):
        """Return a dict of per-head Q-value tensors for input ``x``."""
        start = self.start(x)
        return {
            'color': self.color(start),
            'style': self.style(start),
            'mood': self.mood(start),
            'composition': self.composition(start)
        }

In [None]:
# DQN hyperparameters
# - BATCH_SIZE: transitions sampled from the replay buffer per optimisation step
# - GAMMA: discount factor for future rewards
# - EPS_START / EPS_END: initial and final exploration probability (epsilon)
# - EPS_DECAY: time constant, in action selections, of epsilon's exponential
#   decay; larger means slower
# - TAU: soft-update rate of the target network (kept small for stability)
# - LR: learning rate of the AdamW optimizer
# NOTE(review): the training loop runs 20 episodes of at most 10 steps
# (ArtCreationEnv.max_iters), i.e. at most ~200 selections, so epsilon only
# decays to about 0.83 and the agent stays mostly exploratory. In addition,
# optimize_model() is a no-op until BATCH_SIZE transitions are stored.
# Consider a smaller EPS_DECAY / BATCH_SIZE for human-in-the-loop runs.
BATCH_SIZE: int = 64
GAMMA: float = 0.99
EPS_START: float = 0.9
EPS_END: float = 0.01
EPS_DECAY: int = 2500
TAU: float = 0.005
LR: float = 3e-4
# Global count of select_action() calls; drives the epsilon schedule.
steps_done = 0


def select_action(state):
  """Choose an action for ``state`` with an epsilon-greedy policy.

  Epsilon decays exponentially from EPS_START towards EPS_END with time
  constant EPS_DECAY, measured in calls (tracked by the global ``steps_done``).

  Args:
    state: input for ``policy_net`` (the training loop passes a float tensor
      of shape (1, n_observations)).

  Returns:
    dict mapping each head name ('color', 'style', 'mood', 'composition') to
    an int option index.
  """
  global steps_done
  sample = random.random()
  eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
  steps_done += 1
  with torch.no_grad():
    q_values = policy_net(state)
  if sample > eps_threshold:
    # Exploit: greedy option per head.
    return {dim: torch.argmax(q).item() for dim, q in q_values.items()}
  # Explore: uniform option per head. Ranges come from the network's own head
  # sizes instead of hard-coded bounds, so random indices always match what
  # optimize_model() can gather and every option stays reachable.
  return {dim: random.randrange(q.shape[-1]) for dim, q in q_values.items()}

def optimize_model():
    """Run one DQN optimisation step on a random minibatch from ``memory``.

    Q(s, a) is the mean over the four heads of each head's Q-value for the
    chosen option; the target is r + GAMMA * mean_over_heads(max target_net Q(s')),
    with 0 bootstrap for terminal transitions (next_state is None). Uses the
    Huber (SmoothL1) loss and clips gradient values to [-100, 100].
    Does nothing until ``memory`` holds at least BATCH_SIZE transitions.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose list-of-Transitions into a Transition of tuples
    # (see https://stackoverflow.com/a/19343/3343043).
    batch = Transition(*zip(*transitions))

    dims = ('color', 'style', 'mood', 'composition')

    # True where the transition has a successor state (was not terminal).
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state)
    reward_batch = torch.cat(batch.reward)
    # (BATCH_SIZE, 1) index tensor per head, the shape .gather(1, ...) expects.
    action_idx = {
        dim: torch.tensor([a[dim] for a in batch.action], device=device).unsqueeze(1)
        for dim in dims
    }

    # Q(s_t, a_t) from the policy network, averaged over the four heads.
    q_values = policy_net(state_batch)
    state_action_values = sum(q_values[dim].gather(1, action_idx[dim]) for dim in dims) / len(dims)

    # V(s_{t+1}) from the target network; stays 0 for terminal transitions.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Guard: torch.cat([]) raises if every sampled transition was terminal.
    if non_final_mask.any():
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        with torch.no_grad():
            next_q = target_net(non_final_next_states)
            next_state_values[non_final_mask] = sum(next_q[dim].max(1).values for dim in dims) / len(dims)
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()



In [None]:
### Training
# Expanded observation space for the DQN. Every DQN output head must produce
# exactly len(obs_space[dim]) Q-values; otherwise some options are unreachable
# or sampled indices fall outside obs_space.
obs_space = {
            'color': ['warm', 'cool', 'neutral', 'dark', 'light', 'vibrant', 'muted', 'monochrome', 'analogous', 'complementary', 'pastel', 'earth-tones','neon','modern','bohemian','spring','summer','winter','autumn','gradient','impressionist','expressionist'],
            'style': ["photorealistic", "hyperrealistic", "realistic","impressionist", "expressionist", "abstract_expressionist","minimalist", "geometric", "abstract","surreal", "dreamlike", "ethereal","oil_painting", "watercolor", "pencil_sketch",
                      "digital_art", "concept_art", "3d_render"],
            'mood': ["calm", "peaceful", "energetic", "chaotic","joyful", "melancholic", "mysterious", "ominous","bright", "dark", "moody", "ethereal", "dramatic","warm", "cool", "neutral"],
            'composition': ["centered", "symmetrical", "asymmetrical","rule_of_thirds", "golden_ratio", "diagonal","shallow_depth", "deep_depth", "layered","close_up", "wide_shot", "portrait", "landscape","static", "dynamic", "flowing"]
}
env = ArtCreationEnv(pipe, obs_space, model = 'dqn')
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations).to(device)
target_net = DQN(n_observations).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)
n_iters = 20
steps_done = 0

for i_episode in tqdm(range(n_iters)):
    # Start a fresh episode; the DQN consumes a (1, n_observations) float tensor.
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    episode_reward = 0
    for t in count():
        action = select_action(state)
        observation, reward, info, terminated, truncated = env.step(action, None)
        episode_reward += reward
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated
        # Terminal transitions store next_state=None so optimize_model()
        # bootstraps a value of 0; truncated ones keep their real next state.
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            print(f"Episode {i_episode+1} finished after {t+1} steps, Total Reward: {episode_reward:.2f}")
            break


### Policy Gradient
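Architecture:
*   Input: state (a 4-d vector of option indices; a 20-d one-hot encoding is an alternative)
*   → shared hidden layers
*   → four heads: color (5 logits), style (5 logits), mood (5 logits), composition (5 logits)
*   each head → softmax (a categorical distribution over that dimension's options)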



Import Necessary libraries


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [None]:
# PolicyNetwork -- the Neural Network Architecture
class PolicyNetwork(nn.Module):
    """Multi-head categorical policy over the four prompt dimensions.

    A shared MLP trunk feeds four linear heads (color, style, mood,
    composition); each head's logits define an independent Categorical
    distribution over that dimension's options.

    Args:
        input_dim: length of the state vector (4 for the raw index tuple).
        n_choices: options per head, either an int applied to all four heads or
            a dict keyed by head name (missing heads default to 5). Default 5.
    """

    HEAD_NAMES = ('color', 'style', 'mood', 'composition')

    def __init__(self, input_dim=4, n_choices=5):
        super(PolicyNetwork, self).__init__()
        if isinstance(n_choices, dict):
            sizes = {name: n_choices.get(name, 5) for name in self.HEAD_NAMES}
        else:
            sizes = {name: int(n_choices) for name in self.HEAD_NAMES}

        # Shared feature extractor.
        self.shared = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )

        # Separate output heads (attribute names kept for state_dict compatibility).
        self.color_head = nn.Linear(64, sizes['color'])
        self.style_head = nn.Linear(64, sizes['style'])
        self.mood_head = nn.Linear(64, sizes['mood'])
        self.comp_head = nn.Linear(64, sizes['composition'])

    def forward(self, state):
        """Return a dict of per-head logits (unnormalised) for ``state``.

        A non-tensor state (tuple/list) is converted to a float tensor with a
        leading batch dimension of 1 on the model's device; a tensor is used
        as-is.
        """
        if not isinstance(state, torch.Tensor):
            x = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(next(self.parameters()).device)
        else:
            x = state

        shared_features = self.shared(x)

        return {"color": self.color_head(shared_features),
                "style": self.style_head(shared_features),
                "mood": self.mood_head(shared_features),
                "composition": self.comp_head(shared_features)}

    def select_action(self, state):
        """Sample one option index per head.

        Expects a single (unbatched or batch-of-1) state, since each sampled
        index is converted with ``.item()``.

        Returns:
            (new_state, log_prob_sum): new_state is a tuple of 4 ints in head
            order; log_prob_sum is a differentiable 0-dim tensor holding the
            sum of the per-head log-probabilities of the sampled indices.
        """
        logits_dict = self.forward(state)

        indices = []
        log_probs = []
        for logits in logits_dict.values():
            dist = Categorical(logits=logits)  # softmax is applied internally
            idx = dist.sample()
            log_probs.append(dist.log_prob(idx).squeeze())
            indices.append(idx.item())

        return tuple(indices), torch.stack(log_probs).sum()


In [None]:
# Policy Gradient Agent Class
class PolicyGradientAgent():
  """REINFORCE agent wrapping a PolicyNetwork.

  choose_action() samples a complete new state (one index per prompt
  dimension) and records its log-probability; store_reward() attaches the
  user's reward to that record. update_policy() consumes ``episode_data``.
  """

  def __init__(self, learning_rate, discount_factor):
    self.policy_net = PolicyNetwork(input_dim=4).to(device)
    self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
    self.discount_factor = discount_factor
    # One (log_prob, reward) pair per step of the current episode; the reward
    # is None until store_reward() fills it in.
    self.episode_data = []

  def choose_action(self, state):
    """Sample the next state and remember its log-probability."""
    sampled_state, log_prob = self.policy_net.select_action(state)
    self.episode_data.append((log_prob, None))
    return sampled_state

  def store_reward(self, reward):
    """Attach ``reward`` to the most recent step (no-op if there is none)."""
    if not self.episode_data:
      return
    log_prob, _ = self.episode_data[-1]
    self.episode_data[-1] = (log_prob, reward)

  def get_best_state(self, curr_state):
    """Return the most probable option index for each head as a 4-tuple."""
    if not isinstance(curr_state, torch.Tensor):
      curr_state = torch.tensor(curr_state, dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
      logits = self.policy_net.forward(curr_state)
      return tuple(
          torch.argmax(logits[head]).item()
          for head in ('color', 'style', 'mood', 'composition')
      )


In [None]:
# Reinforce update function
def update_policy(agent):
  """Apply one REINFORCE update from ``agent.episode_data`` and clear it.

  loss = -sum_t log_prob_t * G_t, where G_t = r_t + gamma * G_{t+1} is the
  discounted return. For episodes with more than one step the returns are
  standardised (zero mean, unit std) to reduce variance; a single-step episode
  uses its raw return, because the std of one sample is NaN and would poison
  the network weights.

  Args:
    agent: object with ``episode_data`` (list of (log_prob tensor, reward)),
      ``discount_factor`` and ``optimizer``.

  Raises:
    ValueError: if a step has no reward (store_reward() was not called).
  """
  if not agent.episode_data:
    return

  log_probs = [entry[0] for entry in agent.episode_data]
  rewards = [entry[1] for entry in agent.episode_data]
  if any(r is None for r in rewards):
    raise ValueError("episode_data contains a step without a reward; "
                     "call agent.store_reward() after every choose_action().")

  # Discounted returns, accumulated backwards through the episode.
  R = 0
  returns = []
  for r in reversed(rewards):
    R = r + agent.discount_factor * R
    returns.append(R)
  returns.reverse()

  # Put the returns on the same device as the log-probs they multiply.
  returns = torch.tensor(returns, dtype=torch.float32, device=log_probs[0].device)
  if len(returns) > 1:
    returns = (returns - returns.mean()) / (returns.std() + 1e-9)

  policy_loss = torch.stack([-log_prob * G for log_prob, G in zip(log_probs, returns)]).sum()

  agent.optimizer.zero_grad()
  policy_loss.backward()
  agent.optimizer.step()

  agent.episode_data = []  # reset episode data

In [None]:
# Training the Policy Gradient Agent

# Hyperparameters. PG_GAMMA is deliberately distinct from the DQN section's
# GAMMA (read by optimize_model), so running this cell cannot change the DQN setup.
LEARNING_RATE = 5e-4
PG_GAMMA = 0.99
NUM_EPISODES = 5
# Rated steps per episode. The env truncates episodes at env.max_iters, so that
# value is set from this constant below to keep the two in sync.
MAX_ITERATIONS_PER_EPISODE = 10
# Small 5-options-per-dimension space (PolicyNetwork's heads output 5 logits each).
obs_space_pg = {
    'color': ['warm', 'cool', 'vibrant', 'muted', 'monochrome'],
    'style': ['realistic', 'painterly', 'abstract', 'minimalist', 'surreal'],
    'mood': ['peaceful', 'energetic', 'mysterious', 'joyful', 'melancholic'],
    'composition': ['centered', 'rule-of-thirds', 'asymmetric', 'symmetrical', 'dynamic']
}
env = ArtCreationEnv(pipe, obs_space_pg, model='policy-gradient')
env.max_iters = MAX_ITERATIONS_PER_EPISODE
agent = PolicyGradientAgent(learning_rate=LEARNING_RATE, discount_factor=PG_GAMMA)

episode_rewards = []

print(f"Starting Policy Gradient (REINFORCE) Training for {NUM_EPISODES} episodes...")

for i_episode in tqdm(range(NUM_EPISODES)):
  state, info = env.reset()
  episode_reward_total = 0

  for t in range(MAX_ITERATIONS_PER_EPISODE):
    new_state = agent.choose_action(state)
    next_state, reward, info, terminated, truncated = env.step(new_state, agent)
    agent.store_reward(reward)  # store the reward the user provided

    state = next_state
    episode_reward_total += reward

    if terminated or truncated:
      break

  # Episode end: one REINFORCE update over the whole episode.
  update_policy(agent)

  episode_rewards.append(episode_reward_total)
  print(f"Episode {i_episode+1} finished. Total Reward: {episode_reward_total:.2f}. Last State: {state}")

print("Training Complete.")
final_state = agent.get_best_state(env.curr_state)
print(f"Best learned state indices: {final_state}")
final_prompt = env.apply_action(final_state)

# Generate the image with the best learned settings.
print(f"\nGenerating final image with: {final_prompt}")
final_image = env.generate_image(final_prompt)
display(final_image)