In [None]:
# Mount Google Drive (Colab only) so the dataset and model files referenced
# below under /content/drive/MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

import pandas as pd

### Transforming from JSON to CSV:
Original dataset link: https://huggingface.co/datasets/huggingface-projects/color-palettes-sd/

In [None]:
import json

# Flatten the nested JSON export into one row per generated image: every image
# listed under a prompt becomes a (prompt, imgURL, colors) record.

# A context manager closes the file handle as soon as parsing is done.
with open('/content/drive/MyDrive/Colab Notebooks/ColorSchemes/dataset/text-to-image-dataset.json', 'r') as json_file:
  json_data = json.load(json_file)

# Collect plain records first and build the frame once; growing a DataFrame
# with `dataset.loc[len(dataset.index)] = ...` copies data on every append.
records = [
  {'prompt': obj['data']['prompt'], 'imgURL': img['imgURL'], 'colors': img['colors']}
  for obj in json_data
  for img in obj['data']['images']
]

# Passing `columns` explicitly keeps the three columns (in this order) even
# when the JSON contains no images at all.
dataset = pd.DataFrame(records, columns=['prompt', 'imgURL', 'colors'])

In [None]:
# index=False: the default also writes the row index, which read_csv then
# returns as a meaningless `Unnamed: 0` column.
dataset.to_csv('/content/drive/MyDrive/Colab Notebooks/ColorSchemes/dataset/text-to-image-dataset-converted.csv', index=False)

### Using the new dataset

In [None]:
# Reload the flattened dataset from Drive so the remaining cells can start from the CSV.
# After the CSV round trip the `colors` column holds text (a stringified list);
# hex_to_rgb in the next section parses it back into numbers.
df = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/ColorSchemes/dataset/text-to-image-dataset-converted.csv')
df.head()

### Converting hex codes to RGB values

In [None]:
def hex_to_rgb(hex_list_str: str):
  """Parse a stringified list of hex colour codes into RGB triplets.

  Args:
    hex_list_str: text such as "['#f5ad31', '#a80928']", i.e. a list of hex
      codes after it has been written to and read back from CSV. Entries may
      use single or double quotes, and whitespace around commas is optional.

  Returns:
    A list containing one [r, g, b] list of ints (0-255) per colour, in the
    same order as the input. An empty list string ("[]" or "") yields [].

  Raises:
    ValueError: if an entry is not '#' followed by exactly six hex digits.
  """
  hex_digits = '0123456789abcdefABCDEF'
  rgb_list = []

  for token in hex_list_str.strip().strip('[]').split(','):
    code = token.strip().strip('"\'')
    if not code:
      # Empty list or stray comma: nothing to convert.
      continue

    # Reject malformed codes explicitly; slicing a code without '#' or with
    # the wrong length would otherwise produce silently wrong channel values.
    if len(code) != 7 or code[0] != '#' or any(ch not in hex_digits for ch in code[1:]):
      raise ValueError(f"expected a hex colour like '#rrggbb', got {code!r}")

    # Characters 1-2, 3-4 and 5-6 hold the red, green and blue channels.
    rgb_list.append([int(code[i:i + 2], 16) for i in (1, 3, 5)])

  return rgb_list

# Replace each stringified palette with its parsed list of [r, g, b] triplets.
# NOTE: this cell is not idempotent. `df` is overwritten here, so running it a
# second time would pass lists (not strings) to hex_to_rgb.
df['colors'] = df['colors'].apply(hex_to_rgb)
# Keep only the palettes; the other CSV columns are not used from here on.
df = df[['colors']]
df.head()

### Prepare the dataset and split into train, test, validation

In [None]:
import torch

# The full palette is the training target, so the column is named `output`.
# Reassign the renamed frame rather than mutating `df` in place.
df = df.rename(columns={'colors': 'output'})

def random_filter(x: list):
  """Pick two entries of `x` at random positions, returned in list order.

  The default torch RNG is re-seeded with 111 on every call, so the same two
  positions are chosen for every list of a given length. This also resets
  torch's global random state as a side effect. The two positions are drawn
  independently and may therefore coincide.

  Args:
    x: non-empty list to sample from.

  Returns:
    A two-element list of entries of `x`, ordered by their position in `x`.
  """
  rand_indices = torch.randint(low=0, high=len(x), size=(2,), generator=torch.manual_seed(111))
  # Tensor.sort() returns a new (values, indices) pair and leaves the tensor
  # untouched, so the ordering has to be taken from an explicit result.
  positions = sorted(rand_indices.tolist())
  return [x[position] for position in positions]

# Derive a two-colour `input` entry from every full palette, then order the
# columns as (input, output).
two_color_inputs = df['output'].apply(random_filter)
df = df.assign(input=two_color_inputs)[['input', 'output']]
df.head()

In [None]:
from torch.utils.data import DataLoader

# def split(full_dataset, val_percent, test_percent, random_seed=None):
#   amount = len(full_dataset)

#   test_amount = (
#     int(amount * test_percent)
#     if test_percent is not None else 0)
#   val_amount = (
#     int(amount * val_percent)
#     if val_percent is not None else 0)
#   train_amount = amount - test_amount - val_amount

#   train_dataset, val_dataset, test_dataset = random_split(
#     full_dataset,
#     (train_amount, val_amount, test_amount),
#     generator = (
#       torch.Generator().manual_seed(random_seed)
#       if random_seed
#       else None
#     )
#   )
#   return train_dataset, val_dataset, test_dataset

# train_data, val_data, test_data = split(df, 0.1, 0.2, 111)

# Every stored palette is used as a "real" training sample.
train_data = df['output']
train_data_length = len(train_data)
print(train_data_length)

# DataLoader items are (sample, label) pairs; the zero labels are placeholders
# that the training loop unpacks and ignores.
train_labels = torch.zeros(train_data_length)
train_set = [
  (torch.tensor(palette), label)
  for palette, label in zip(train_data, train_labels)
]

BATCH_SIZE = 32  # number of palettes per training batch
train_loader = DataLoader(
  train_set,
  batch_size=BATCH_SIZE,
  shuffle=True,
  drop_last=True,  # every batch has exactly BATCH_SIZE samples
)  # Training Data

9841


## Building GAN
Input must be two colors. Output should be 3 more colors' values.\
Suggestions from [Sandhya Krishnan (Geek Culture)](https://medium.com/geekculture/introduction-to-neural-network-2f8b8221fbd3):
- The number of hidden neurons should be between the size of the input layer and the size of the output layer.
- The number of hidden neurons should be 2/3 the size of the input layer, plus the size of the output layer.
- The number of hidden neurons should be less than twice the size of the input layer.

size of input layer = 3 + 3 = 6\
size of output layer = 3 + 3 + 3 = 9

In [None]:
from torch import nn

class Generator(nn.Module):
  """Extends two RGB colours into a five-colour palette.

  forward() takes a (batch, 2, 3) tensor of channel values in [0, 255] and
  returns a (batch, 5, 3) tensor: the two given colours unchanged, followed
  by three generated colours, also in [0, 255].
  """

  def __init__(self):
    super().__init__()
    self.sample_in = None
    self.sample_out = None

    # Architectures tried earlier (LeakyReLU between layers, Sigmoid at the end):
    #   V2: 6 -> 128 -> Dropout(0.1) -> 60 -> 49 -> 9
    #   V3: 6 -> 128 -> 94 -> 72 -> 56 -> 9
    # V1 and V4 (current): 6 -> 18 -> 27 -> 9
    layers = [
      nn.Linear(6, 18),
      nn.LeakyReLU(),
      nn.Linear(18, 27),
      nn.LeakyReLU(),
      nn.Linear(27, 9),
      nn.Sigmoid(),  # keeps every generated channel inside [0, 1]
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    batch_size = x.shape[0]

    # Flatten the two colours and rescale [0, 255] -> [0, 1] for the network.
    scaled = x.reshape(batch_size, 6) / 255

    # Three generated colours, rescaled [0, 1] -> [0, 255].
    generated = self.model(scaled).reshape(batch_size, 3, 3) * 255

    # Given colours first, generated colours after them.
    return torch.cat((x, generated), dim=1)

class Discriminator(nn.Module):
  """Scores a five-colour palette as real (towards 1) or generated (towards 0).

  forward() takes a (batch, 5, 3) tensor and returns a (batch, 1) tensor of
  sigmoid outputs.
  """

  def __init__(self):
    super().__init__()

    layers = []
    # Three hidden stages, each Linear -> LeakyReLU -> Dropout(0.3).
    for in_features, out_features in ((15, 256), (256, 128), (128, 64)):
      layers.append(nn.Linear(in_features, out_features))
      layers.append(nn.LeakyReLU())
      layers.append(nn.Dropout(0.3))

    layers.append(nn.Linear(64, 1))
    layers.append(nn.Sigmoid())  # 0 means input was fake and 1 means input was real

    self.model = nn.Sequential(*layers)

  def forward(self, x):
    batch_size = x.shape[0]
    # Flatten the 5 colours x 3 channels into one 15-value vector per sample.
    flat = x.reshape(batch_size, 15)
    return self.model(flat).reshape(batch_size, 1)

### Train the model

In [None]:
generator = Generator()
discriminator = Discriminator()

LR = 0.001
EPOCHS = 300
loss_fn = nn.BCELoss()

d_optim = torch.optim.Adam(discriminator.parameters(), lr=LR)
g_optim = torch.optim.Adam(generator.parameters(), lr=LR)

# Fixed generator input: BATCH_SIZE random colour pairs with channels in [0, 255).
# NOTE(review): these are drawn once and reused for every batch, so the
# generator only ever trains on these seed pairs; confirm this is intended.
latent_space_samples = torch.rand(size=(BATCH_SIZE, 2, 3), generator=torch.manual_seed(111))*255 # fixed noise for generator

REAL_DATA_LABEL = torch.ones((BATCH_SIZE, 1)) # 1 - real data
GEN_DATA_LABEL = torch.zeros((BATCH_SIZE, 1)) # 0 - fake data
# Label order matches the (real, generated) order used when building all_samples.
DATA_LABELS = torch.cat((REAL_DATA_LABEL, GEN_DATA_LABEL))

# Index of the final batch in an epoch. The progress message is tied to this
# rather than to BATCH_SIZE, which is unrelated to the number of batches.
LAST_BATCH_INDEX = len(train_loader) - 1

for epoch in range(1, EPOCHS + 1):
  for n, (real_samples, _) in enumerate(train_loader):
    # --- Training the discriminator ---
    # detach(): the discriminator step must not backpropagate into the
    # generator; those gradients were computed and then thrown away.
    generated_samples = generator(latent_space_samples).detach()

    # Concatenate real and generated palettes into one input batch. The real
    # samples are cast explicitly instead of relying on dtype promotion.
    all_samples = torch.cat((real_samples.to(generated_samples.dtype), generated_samples))

    d_optim.zero_grad() # clears the discriminator's gradients
    output_discriminator = discriminator(all_samples)

    loss_discriminator = loss_fn(output_discriminator, DATA_LABELS)
    loss_discriminator.backward()
    d_optim.step()

    # --- Training the generator ---
    g_optim.zero_grad() # clears the generator's gradients
    generated_samples = generator(latent_space_samples)
    output_discriminator = discriminator(generated_samples)

    # Generator loss: its palettes should be classified as real.
    loss_generator = loss_fn(output_discriminator, REAL_DATA_LABEL) # Generator must produce realistic outputs
    loss_generator.backward()
    g_optim.step()

    # Report once every 10 epochs, after the last batch of that epoch.
    if (epoch % 10 == 0) and (n == LAST_BATCH_INDEX):
      # .item() prints plain numbers instead of the tensor repr with grad_fn.
      loss_value = f"Epoch: {epoch}. Discriminator Loss: {loss_discriminator.item()}. Generator Loss: {loss_generator.item()}"
      print(loss_value)

In [None]:
import time
# Timestamp suffix so that each run is written to a new file name.
ts = str(time.time())
# Only the state_dicts (weights) are stored; loading them requires the
# Generator / Discriminator classes to be defined first.
# NOTE(review): the test cell loads '.../model/GeneratorV4', a name this cell
# never produces. Presumably the file is renamed by hand; confirm.
torch.save(generator.state_dict(), '/content/drive/MyDrive/Colab Notebooks/ColorSchemes/model/Generator.'+ts)
torch.save(discriminator.state_dict(), '/content/drive/MyDrive/Colab Notebooks/ColorSchemes/model/Discriminator.'+ts)

### Testing out the model

In [None]:
# Build a fresh Generator and restore previously saved weights into it.
# NOTE(review): 'GeneratorV4' is not a file name written by the save cell
# (which uses 'Generator.<timestamp>'); this file must already exist on Drive.
gen_model = Generator()
gen_model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/ColorSchemes/model/GeneratorV4'))
gen_model.eval()  # evaluation mode for inference

def rgb_to_hex(r, g, b):
  """Format three integer channel values as a lowercase '#rrggbb' string.

  Each channel is written as at least two hex digits, zero-padded.
  """
  channels = (r, g, b)
  return '#{:02x}{:02x}{:02x}'.format(*channels)

# Five hand-picked colour pairs (RGB channels, 0-255) used as generator input.
seed_pairs = torch.tensor([
    [[245., 173., 49.], [168., 9., 40.]],
    [[58., 71., 36.], [41., 95., 105.]],
    [[80., 65., 49.], [33., 25., 21.]],
    [[178., 141., 74.], [204., 203., 160.]],
    [[26., 46., 53.], [156., 176., 115.]],
])

# Round the generated channels and cast to int32 so they can be hex-formatted.
out = torch.round(gen_model(seed_pairs).detach()).to(torch.int32)

# Print each palette as one hex code per line, with a blank line between palettes.
for palette in out:
  hex_codes = [rgb_to_hex(*color) for color in palette]
  print(*hex_codes, sep='\n')
  print()