In [None]:
#@title **FUNCTIONS THAT HELP TIMECODED MANIPULATIONS**
import numpy as np

# SIMPLE PROMPT CHECKS

# The simple function here counts your tokens
# before you pass your prompts, so that you make
# sure you don't have modifiers disabled without
# realising.
from transformers import GPT2TokenizerFast
# Module-level GPT-2 tokenizer, created once and shared by count_tokens().
# NOTE(review): from_pretrained("gpt2") presumably loads or downloads the
# tokenizer files when this cell runs - confirm this works in offline sessions.
my_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def count_tokens(prompt):
  """Return the number of tokens `prompt` is encoded into.

  The prompt is passed through the module-level `my_tokenizer` and the
  length of the resulting 'input_ids' sequence is returned.
  """
  token_ids = my_tokenizer(prompt)['input_ids']
  return len(token_ids)

# You can validate all prompts in a timecoded 
# prompt dict and make sure no modifier will
# be dropped.
def validate_prompts(prompts, max_token=75):
  """Check that no prompt in a timecoded prompt dict is longer than `max_token`.

  Args:
    prompts: dict mapping a timecode to a prompt string.
    max_token: largest token count (as reported by count_tokens) accepted.

  Raises:
    AssertionError: for the first prompt whose token count exceeds
      `max_token`. The offending prompt and its token count are printed
      before raising.
  """
  for p in prompts.values():
    n = count_tokens(p)
    if n > max_token:
      print("ERROR : Prompt is more than " + str(max_token) + " tokens. Prompt : " + str(p))
      print("Total : " + str(n) + " tokens")
      # Raise explicitly instead of `assert False`: assert statements are
      # removed when Python runs with -O, which would silently disable this
      # check. The exception type is unchanged for existing callers.
      raise AssertionError(
          "Prompt has " + str(n) + " tokens, more than the allowed " + str(max_token))

# You can print prompts in a more easily readable way.
def print_prompts(p):
  """Print a timecoded prompt dict, one "timecode prompt" pair per line."""
  for timecode in p:
    prompt = p[timecode]
    # Token counts are not printed; count_tokens(prompt) gives them if needed.
    print(timecode, prompt)


# FUNCTIONS THAT GENERATE / MANIPULATE PROMPT DICTS

# This function helps to generate a timecoded dict
# based on an input list of prompts, from 0 to 'max_t'
# for each step of size 'interval'.
def generate_prompts(prompt_list, interval=10, max_t=1000, shuffle_list=False):
  """Build a timecoded prompt dict by cycling through `prompt_list`.

  Prompts are assigned to the timecodes 0, interval, 2*interval, ... A new
  pass over the list starts as long as the current timecode is below
  `max_t`; a pass that has started is always completed, so the last
  timecodes can reach or exceed `max_t` when the list length does not
  divide the range evenly.

  Args:
    prompt_list: sequence of prompt strings. It is never modified.
    interval: distance between two consecutive timecodes; must be positive.
    max_t: no new pass over the list starts at or after this timecode.
    shuffle_list: when True, the prompt order is reshuffled (with the
      global numpy random state) before every pass.

  Returns:
    dict mapping timecode -> prompt. Empty when `prompt_list` is empty or
    `max_t` is not positive.

  Raises:
    ValueError: if `interval` is not positive while prompts have to be
      generated (the timecode would never advance).
  """
  new_prompts = {}

  # Nothing to schedule. An empty list used to spin forever here because
  # the timecode never advanced.
  if max_t <= 0 or len(prompt_list) == 0:
    return new_prompts
  if interval <= 0:
    raise ValueError("interval must be positive, got " + str(interval))

  # Work on a private copy so shuffling does not reorder the caller's list.
  # Shuffling the same copy on every pass yields the same sequence of
  # orders as shuffling the original in place did.
  cycle = list(prompt_list)

  t = 0
  while t < max_t:
    if shuffle_list:
      np.random.shuffle(cycle)

    for p in cycle:
      new_prompts[t] = p
      t += interval

  return new_prompts


# This function merges two timecoded dicts of prompts.
# Merge in the sense that:
# prompts_1 has some keywords, prompts_2 has completely different 
# keywords that are COMPLEMENTARY to the words of prompts_1.
# For example, maybe prompts_1 has some main keywords and
# prompts_2 contains modifiers for style. As a result, we
# get a timecoded dict of both the keywords and style.
# The two inputs DON'T HAVE TO have the same keys (timecodes), 
# and don't have to be the same size.
# The merging will take into account the overlaps, changes, etc.
def merge_timecodes(prompts_1, prompts_2):
  """Merge two timecoded prompt dicts into one combined schedule.

  The result has one entry for every timecode that appears in either input.
  At each timecode the value is "<prompt from prompts_1>, <prompt from
  prompts_2>", where each part is the prompt active at that time: the one
  with the latest timecode not after it. Before a dict's first timecode,
  its first prompt is used.

  Args:
    prompts_1: dict mapping timecode -> prompt text (e.g. main keywords).
    prompts_2: dict mapping timecode -> prompt text (e.g. style modifiers).
      The two dicts may have different timecodes and sizes, and their keys
      do not need to be inserted in ascending order.

  Returns:
    A new dict, in ascending timecode order, whose keys keep the type used
    by the inputs. If one input is empty, the result is a sorted copy of
    the other one; if both are empty, an empty dict.
  """
  # Merging with an empty schedule leaves the other schedule unchanged
  # (this used to fail with an IndexError).
  if not prompts_1:
    return {t: prompts_2[t] for t in sorted(prompts_2)}
  if not prompts_2:
    return {t: prompts_1[t] for t in sorted(prompts_1)}

  # Sort explicitly: the walk below relies on ascending timecodes and must
  # not depend on the insertion order of the input dicts.
  keys_1 = sorted(prompts_1)
  keys_2 = sorted(prompts_2)

  new_prompts = {}
  i1 = 0
  i2 = 0
  for t in sorted(set(keys_1) | set(keys_2)):
    # Advance each cursor to the last timecode that is not after t.
    while i1 + 1 < len(keys_1) and keys_1[i1 + 1] <= t:
      i1 += 1
    while i2 + 1 < len(keys_2) and keys_2[i2 + 1] <= t:
      i2 += 1

    new_prompts[t] = prompts_1[keys_1[i1]] + ", " + prompts_2[keys_2[i2]]

  return new_prompts

# This function loops over a list of N styles and
# for each prompt in prompts, replicates the 
# original N times, having a copy with each style.
# The result will be N times as large as the original
# prompts. The input prompts dict is modified in place.
# This function assumes there are sufficient
# integer times for all styles between two prompts.
# Otherwise it will overwrite some styles.
def loop_styles(prompts, styles, shuffle_styles=None, run_id=None):
  """Replicate every prompt once per style, spread over the time to the next prompt.

  For a prompt at timecode T followed by a prompt at timecode T2, the
  entries T, T+step, T+2*step, ... (step = int((T2 - T) / len(styles))) are
  set to "<prompt>,<style>", one per style. The last prompt has no
  successor and reuses the step of the previous prompt, or a step of 1 when
  it is the only prompt. If step is 0 (fewer integer times than styles),
  the styles overwrite each other on the same timecode.

  Args:
    prompts: dict mapping timecode -> prompt. Modified in place.
    styles: sequence of style strings. It is never modified.
    shuffle_styles: when true, the style order is reshuffled for every
      prompt. None (default) falls back to a module-level variable named
      `shuffle_styles` if one exists, else False.
    run_id: integer added to the prompt index to seed the shuffle. None
      (default) falls back to a module-level `run_id` if one exists, else 0.

  Returns:
    None.
  """
  # Earlier versions read these settings from notebook globals only; keep
  # that working while allowing callers to pass them explicitly.
  if shuffle_styles is None:
    shuffle_styles = globals().get("shuffle_styles", False)
  if run_id is None:
    run_id = globals().get("run_id", 0)

  # Private copy: the shuffle below must not reorder the caller's list.
  styles = list(styles)
  n_styles = len(styles)
  if n_styles == 0:
    # No styles to apply; also avoids dividing by zero below.
    return

  # Snapshot of the original timecodes, ascending, so that intervals are
  # positive and entries added below are not visited again.
  times = sorted(prompts)
  # Fallback spacing for a dict with a single prompt (no interval to split).
  step = 1
  for i, time in enumerate(times):

    if shuffle_styles:
      # Seeds the global numpy random state so a given run_id always
      # produces the same style order.
      np.random.seed(i + run_id)
      np.random.shuffle(styles)

    if i < len(times) - 1:
      interval = times[i + 1] - time
      step = int(interval / n_styles)
    bare_prompt = prompts[time]

    for j, style in enumerate(styles):
      prompts[time + j * step] = bare_prompt + "," + style


# This function adds a given modifier to the list of indices
# in the prompt list. If the prompt list is a dict(), the
# 'target_indices' arg must be a list of keys.
def add_modifier(prompt_list, target_indices, modifier):
  """Append "," + `modifier` to the selected prompts, in place.

  Args:
    prompt_list: list or dict of prompt strings; modified in place.
    target_indices: indices (for a list) or keys (for a dict) of the
      prompts to extend. An index listed twice is extended twice.
    modifier: text appended after a comma.
  """
  for idx in target_indices:
    current = prompt_list[idx]
    prompt_list[idx] = current + "," + modifier




# This function generates a set of timecoded values with
# plateaus around certain input 'times'. Other than around these 
# 'times', it will have the nominal_value, and when it approaches 
# one of the 'times' it will approach the transition_value,
# based on how wide the transition_range is.
def generate(times, nominal_value=0.75, transition_value=0.6, transition_range=10, time_offset=0):
  """Build a timecoded value dict with plateaus around the given `times`.

  Every time is first shifted back by `time_offset` and clamped at 0. The
  first time only pins `nominal_value` at its (shifted) position. Every
  later time t adds keyframes at t + transition_range (`transition_value`)
  and t + 2*transition_range (`nominal_value`); when t is more than
  2*transition_range it also adds the mirrored keyframes before it, at
  t - 2*transition_range (`nominal_value`) and t - transition_range
  (`transition_value`). Keyframes written later overwrite earlier ones on
  the same timecode.

  Returns:
    dict mapping timecode -> value.
  """
  near = transition_range
  far = 2 * transition_range

  schedule = {}
  for index, raw_time in enumerate(times):
    t = max(raw_time - time_offset, 0)

    if index == 0:
      schedule[t] = nominal_value
      continue

    # Lead-in keyframes only fit when there is enough room before t.
    if t > far:
      schedule[t - far] = nominal_value
      schedule[t - near] = transition_value
    schedule[t + near] = transition_value
    schedule[t + far] = nominal_value

  return schedule


# This function converts a dict into a string
# that can be passed into Animation Parameters.
def convert_to_str(str_dict, add_brackets=False):
  """Serialise a timecoded dict as "time:value,time:value,...".

  Entries are emitted in ascending timecode order. With `add_brackets`
  each value is wrapped in parentheses ("time:(value)"). An empty dict
  gives an empty string.
  """
  ordered_times = np.sort(list(str_dict.keys()))
  if add_brackets:
    entries = [str(t) + ":(" + str(str_dict[t]) + ")" for t in ordered_times]
  else:
    entries = [str(t) + ":" + str(str_dict[t]) for t in ordered_times]
  return ",".join(entries)


# This function generates a set of timecoded values with
# plateaus around certain input 'times' (see generate function above) in a format
# that can be passed into Animation Parameters.
def generate_strength(times, down_value=0.6, up_value=0.75, transition_range=10, offset=0):
  """Return a bracketed timecode string with plateaus around `times`.

  Builds the keyframe dict with generate() and serialises it with
  convert_to_str(..., add_brackets=True), i.e. as "time:(value),...".
  """
  # NOTE(review): down_value is passed as generate()'s nominal_value and
  # up_value as its transition_value. The defaults here (0.6 / 0.75) are the
  # reverse of generate()'s own defaults (0.75 / 0.6) - confirm this is intended.
  schedule = generate(
      times,
      nominal_value=down_value,
      transition_value=up_value,
      transition_range=transition_range,
      time_offset=offset,
  )
  return convert_to_str(schedule, add_brackets=True)


# This function converts a timecode string into a dict, so that
# manipulations can be more easily made.
def convert_to_dict(s):
  """Parse a timecode string ("time:value,time:value,...") into a dict.

  Entries are separated by commas that are not inside parentheses, so a
  value such as "(where(t<5,1,2))" stays in one piece. Each entry is split
  at its first ":"; the part before it is converted with int(), the part
  after it is kept as an unmodified string. Blank entries (empty input,
  trailing comma) are skipped.

  Args:
    s: the timecode string.

  Returns:
    dict mapping int timecode -> value string.

  Raises:
    ValueError: if an entry has no ":" or its timecode is not an integer.
  """
  # Split on top-level commas only; commas nested in parentheses belong to
  # an expression, not to the entry list.
  items = []
  depth = 0
  start = 0
  for pos, ch in enumerate(s):
    if ch == "(":
      depth += 1
    elif ch == ")":
      depth -= 1
    elif ch == "," and depth <= 0:
      items.append(s[start:pos])
      start = pos + 1
  items.append(s[start:])

  d = {}
  for item in items:
    if not item.strip():
      continue
    # Split at the first ":" only, so the value itself may contain one.
    key_text, separator, value = item.partition(":")
    if not separator:
      raise ValueError("Timecode entry without ':' separator: " + repr(item))
    d[int(key_text)] = value
  return d



# This function shifts the time of a timecode string
# by a given amount.
def shift_time(s, time_offset):
  """Shift a timecode string so that time `time_offset` becomes time 0.

  Every keyframe time k becomes k - time_offset. Keyframes that would land
  before 0 are collapsed onto 0; the one directly before the offset gets
  the value linearly interpolated between it and the next keyframe at the
  offset, when both values are plain numbers (optionally parenthesised).
  Values that are expressions are kept as they are. Finally the variable
  `t` in the values is rewritten to "(t+<time_offset>)" so expressions keep
  evaluating against the original timeline.

  Args:
    s: timecode string as accepted by convert_to_dict().
    time_offset: amount of time to remove from the start.

  Returns:
    The shifted timecode string, as produced by convert_to_str().
  """
  import re  # local import: only needed for the variable rewrite below

  t_string = "(t+" + str(time_offset) + ")"
  d = convert_to_dict(s)

  new_dict = {}
  # Sorted so that times[i+1] really is the following keyframe, whatever
  # order the entries had in the input string.
  times = sorted(d)
  for i, k in enumerate(times):
    t = k - time_offset
    v = d[k]
    if t < 0:
      if i + 1 < len(times):
        next_time = times[i + 1]
        if next_time > time_offset:
          # This keyframe and the next one straddle the new time 0:
          # evaluate the value at the offset.
          fraction = (time_offset - k) / (next_time - k)
          v = _interpolate_timecode_value(v, d[next_time], fraction)
      t = 0
    # Later keyframes mapped onto 0 overwrite earlier ones, so the entry
    # closest to the offset wins.
    new_dict[t] = v

  new_s = convert_to_str(new_dict)
  # Rewrite only the standalone variable `t`, not the letter inside names
  # such as "sqrt" or "tan".
  new_s = re.sub(r"\bt\b", lambda _match: t_string, new_s)

  return new_s


def _interpolate_timecode_value(start, end, fraction):
  """Linearly interpolate between two timecode values at `fraction` (0..1).

  Values may be numbers or numeric strings, optionally wrapped in one pair
  of parentheses. Returns a string (parenthesised if `start` was), or
  `start` unchanged when either value is not a plain number.
  """
  def _as_number(value):
    text = str(value).strip()
    if text.startswith("(") and text.endswith(")"):
      text = text[1:-1]
    return float(text)

  try:
    low = _as_number(start)
    high = _as_number(end)
  except ValueError:
    # At least one side is an expression; it cannot be interpolated here.
    return start

  blended = low + (high - low) * fraction
  # Round to 10 significant digits to drop floating-point noise.
  text = str(float(format(blended, ".10g")))
  if str(start).strip().startswith("("):
    return "(" + text + ")"
  return text







