In [None]:
!pip install -U scikit-learn

In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Fix the stdlib RNG seed so the synthetic data below is reproducible.
# Only Python's `random` module is seeded here; NumPy's RNG is not.
random.seed(a=0)

In [None]:
# Category labels, kept for reference only; the Category objects used for
# data generation are built separately.
category_names = 'Bottle Pen Clothing Drink Footwear'.split()

In [None]:
class Category:
    """
    Configuration for one product category used in synthetic data generation.

    Parameters
    ----------
    name : str
        Category label copied into every generated row.
    max_stock : int
        Inclusive upper bound for sampled stock levels. Must be >= 1, since
        sampled stock is divided by it to obtain a stock fraction.
    discount_level : float
        Largest discount an item in this category can receive (e.g. 0.2).
        Stored on the instance as ``max_discount``.

    Raises
    ------
    ValueError
        If ``max_stock`` is less than 1.
    """
    def __init__(self, name, max_stock, discount_level):
        if max_stock < 1:
            # A zero/negative bound would only fail later, inside
            # random.randint or the stock-fraction division; fail fast instead.
            raise ValueError(f"max_stock must be >= 1, got {max_stock!r}")
        self.name = name
        self.max_stock = max_stock
        self.max_discount = discount_level

In [None]:
def choose_discount(sample, max_, max_disc, ndigits=2):
    """
    Choose a discount proportional to the fraction of stock on hand.

    The discount scales linearly with ``sample / max_``: zero stock maps to 0
    and full stock maps to ``max_disc`` (after rounding).

    Parameters
    ----------
    sample : int or float
        Current stock level.
    max_ : int or float
        Maximum stock level for the category; must be non-zero.
    max_disc : float
        Discount assigned when ``sample == max_``.
    ndigits : int, optional
        Decimal places to round the discount to (default 2).

    Returns
    -------
    float
        ``round(sample / max_ * max_disc, ndigits)``.

    Raises
    ------
    ZeroDivisionError
        If ``max_`` is 0.
    """
    pct_stock = sample / max_
    return round(pct_stock * max_disc, ndigits)

In [None]:
def generate_data_single_cat(cat, num_samples, rng=random):
    """
    Generate synthetic ``[name, stock, discount]`` rows for one category.

    Parameters
    ----------
    cat : Category
        Supplies ``name``, ``max_stock`` and ``max_discount``.
    num_samples : int
        Number of rows to generate.
    rng : object with a ``randint`` method, optional
        Source of randomness. Defaults to the global ``random`` module, so a
        notebook-level ``random.seed`` controls it; pass a ``random.Random``
        instance for an independent, reproducible stream.

    Returns
    -------
    list of list
        ``num_samples`` rows ``[cat.name, stock, discount]``. ``stock`` is drawn
        uniformly from ``0..cat.max_stock`` inclusive; ``discount`` comes from
        ``choose_discount``.
    """
    category_data = []
    for _ in range(num_samples):
        sample_stock = rng.randint(0, cat.max_stock)
        discount = choose_discount(sample_stock, cat.max_stock, cat.max_discount)
        category_data.append([cat.name, sample_stock, discount])
    return category_data

In [None]:
# Smoke test: 10 rows for a throwaway 'hat' category.
z = generate_data_single_cat(
    Category(name='hat', max_stock=100, discount_level=0.2),
    num_samples=10,
)

In [None]:
# Rows mix str/int/float, so a plain np.asarray would coerce every cell to a
# fixed-width string dtype ('<U...'). dtype=object keeps the original Python
# types, so the stock and discount columns stay numeric.
z = np.asarray(z, dtype=object)

In [None]:
# Inspect the discount column (index 2) of the sample.
print(z[:, 2])

In [None]:
def generate_data_all_cats(categories, samples_per):
    """
    Build one combined dataset with ``samples_per`` rows for every category.

    Rows are grouped by category in the order ``categories`` is iterated, so
    the i-th category's rows occupy indices ``i*samples_per`` up to
    ``(i+1)*samples_per - 1``.
    """
    return [
        row
        for cat in categories
        for row in generate_data_single_cat(cat, samples_per)
    ]

In [None]:
# One Category per row: (name, max_stock, max discount).
cats = [
    Category(label, stock_cap, disc_cap)
    for label, stock_cap, disc_cap in (
        ('Bottle', 100, 0.2),
        ('Pen', 1000, 0.5),
        ('Clothing', 500, 0.4),
        ('Drink', 100, 0.2),
        ('Footwear', 50, 0.3),
    )
]

In [None]:
# 400 rows per category; rows are grouped by category in `cats` order.
X = generate_data_all_cats(categories=cats, samples_per=400)

In [None]:
print(len(X))
# Show the first row of each category. Rows come in equal-sized per-category
# blocks, so stepping by the block size visits each category's first row
# without hard-coding the per-category sample count.
print(*X[::len(X) // len(cats)])

In [None]:
columns = ['Name', 'Stock', 'Discount']
df = pd.DataFrame(X, columns=columns)
# to_csv does not create missing directories, so make sure the output folder
# exists before writing (otherwise this fails on a fresh checkout).
out_path = Path("discount_data") / "arg-synth-data_0.csv"
out_path.parent.mkdir(parents=True, exist_ok=True)
# NOTE: the DataFrame index is still written as the first (unnamed) column.
df.to_csv(out_path)