# Instructions to reproduce the results on GCP

### GPU:

- Go to [AI platform notebooks](https://cloud.google.com/ai-platform-notebooks/)
- Press "Go to console". If you don't see "Go to console", you need to create
  a free GCP account:
    - Press "Get started for free" and create an account.
    - Open [AI platform notebooks](https://cloud.google.com/ai-platform-notebooks/)
    - Press "Go to console"
    - Press "Enable API"
    - Wait for the previous step to finish and press "GO TO INSTANCES PAGE"
- Press "New Instance" at the top -> "Customize instance".
- Create a new instance with the following specifications:
    - Region: **us-central1**
    - Zone: **us-central1-b** (this is important for TPU)
    - Environment: TensorFlow 2.1 Enterprise
    - Machine type: `n1-highcpu-96` (96 vCPUs, 86.4 GB RAM)
    - GPU Type: NVIDIA Tesla V100
    - Tick "install NVIDIA GPU driver"
    - Press create
    - Press JUPYTERLAB and upload this notebook


### TPU:
- Follow the steps above; the GPU setup may be omitted.
- Go to [Compute Engine -> TPUs](https://console.cloud.google.com/compute/tpus)
- Follow the hints to create a TPU node with the following specs:
    - Zone: **us-central1-b**
    - TPU type: `v2-8` (you can use `v3-8` for better performance)
    - TPU software version: `nightly`
    - Press "create"
- Wait for the node to start up. Note its internal IP (something like `10.245.84.146`); the session cells below connect to it on port 8470.
- Open the JUPYTERLAB created by the steps above.




# Notes

In this colab we perform a CVA-like calculation for a batch of vanilla interest
rate swaps. We assume an underlying Hull-White model for the short rate.
We propagate the rate for 136 steps and price 1 million swaps at each iteration.
The swaps have tenures of up to 30 years with payment frequencies varying from
1 to 12 months.

The whole procedure takes **3 seconds** on a **TPU**. This is currently done in single precision but should shortly be available in double precision too.

We compare sampling speed against CPU and GPU and provide a reference to QuantLib sampling speed.


# Hull-White future yield curves

For the single-factor Hull-White model, the conditional forward bond prices are of the [affine form](https://en.wikipedia.org/wiki/Hull%E2%80%93White_model#Bond_pricing_using_the_Hull%E2%80%93White_model):

$$P(S, T) = A(S, T) e^{-B(S,T) r(S)}$$

where

$$B(S, T) = \frac{1}{\alpha} \left(1 - e^{-\alpha (T-S) } \right)$$

and

$$\begin{eqnarray}
\ln A(S, T) &=& \ln \frac{P(0, T)}{P(0, S)} + B(S, T) f(0, S) - \frac{\sigma^2}{4\alpha^3}\left[1-e^{-\alpha (T-S)}\right]^2 (1-e^{-2\alpha S}) \\
&=& \ln \frac{P(0, T)}{P(0, S)} + B(S, T) f(0, S) - \frac{\sigma^2}{4\alpha}B(S,T)^2 (1-e^{-2\alpha S})
\end{eqnarray}
$$

and
$$f(0, S) = -\frac{\partial}{\partial S} \ln P(0, S)$$

Assuming we are given the pair $(S, r(S))$, we can use the above to compute the discount factors as "observed" at time $S$. The set of future times will be given to us, and $r$ at those times will be computed by sampling (see the next section).
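
The affine formula can be checked on a toy case. The sketch below is plain Python and is not part of the notebook's pipeline: it assumes a flat initial curve $P(0, t) = e^{-r_0 t}$ (so that $f(0, S) = r_0$), and the function name and parameter values are illustrative only.

```python
import math

def hull_white_bond_price(s, t, short_rate, alpha, sigma, flat_rate):
    """P(S, T) under one-factor Hull-White, assuming a flat initial curve.

    Assumption (not from the notebook): P(0, t) = exp(-flat_rate * t),
    so the instantaneous forward rate f(0, S) equals flat_rate.
    """
    b = (1.0 - math.exp(-alpha * (t - s))) / alpha
    log_p_ratio = -flat_rate * (t - s)  # ln [P(0, T) / P(0, S)]
    fwd = flat_rate                     # f(0, S)
    log_a = (log_p_ratio + b * fwd
             - sigma ** 2 / (4.0 * alpha) * b ** 2
             * (1.0 - math.exp(-2.0 * alpha * s)))
    return math.exp(log_a - b * short_rate)

# At S = 0 with r(0) = f(0, 0), the formula reproduces today's curve.
p_today = hull_white_bond_price(0.0, 5.0, 0.03,
                                alpha=0.05, sigma=0.01, flat_rate=0.03)
print(p_today, math.exp(-0.03 * 5.0))
```

Two sanity properties follow directly from the formula: $P(S, S) = 1$ because $B(S, S) = 0$, and for $T > S$ a higher $r(S)$ gives a lower bond price because $B(S, T) > 0$.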

In [1]:
#@title Imports
import numpy as np
import tensorflow.compat.v2 as tf
import time
!pip install --upgrade tf_quant_finance -q
# Disable eager execution since TPU routines are better handled in graph mode
tf.compat.v1.disable_eager_execution()
import tf_quant_finance as tff

# Load TFF dates library
dates = tff.experimental.dates

In [2]:
#@title Global dtype. TPU will soon support FP64
dtype = np.float32 #@param

In [3]:
#@title TFF Fixing function and VanillaSwap class

def get_fixings(*, dates_tensor, discount_fn, day_count, dtype):
  """Computes fixings implied by the input dates tensor and a discounting curve.

  Given dates `[d_1, d_2, ..., d_n]`, computes the forward growth factors
  `fwd_rates = discount_fn(d_i) / discount_fn(d_{i+1})` for `i = 1, ..., n-1`
  and returns the corresponding deposit rates, defined by
  `1 + deposit_rates * day_count(d_i, d_{i+1}) = fwd_rates`.

  Args:
    dates_tensor: `DateTensor` of shape `[num_dates, batch_shape]`.
    discount_fn: A callable that maps `DateTensor` to a real number of `dtype`
      which corresponds to discounting.
    day_count: A daycounting function.
    dtype: Output `dtype`.
  
  Returns:
    A `Tensor` of the specified `dtype` and of shape
    `[num_dates - 1, batch_shape]` that corresponds to the deposit rates at
    `dates_tensor[1:]`.
  """ 
  start_dates = dates_tensor[:-1]
  end_dates = dates_tensor[1:]
  disc_start = discount_fn(start_dates)
  disc_end = discount_fn(end_dates)
  t = day_count(start_date=start_dates, end_date=end_dates, dtype=dtype)
  if t.shape.as_list() != disc_end.shape.as_list():
    t = tf.expand_dims(t, axis=-1)
  # Guard against division by (near-)zero accrual periods; `t > 1e-8` also
  # covers non-positive day count fractions.
  fixings = tf.where(t > 1e-8,
                     (disc_start/disc_end - 1.0) / t,
                     0.0)
  return fixings

class VanillaSwap:
  """Simple interest rate swap."""
  def __init__(self,
               *,
               calc_date,
               fixed_leg_dates,
               float_leg_dates,
               fixed_leg_rates,
               float_leg_rates,
               notional,
               day_count,
               discount_fn,
               dtype=None):
    """Initializer.
    
    Args:
      calc_date: A scalar `DateTensor`. The reference date to which cashflows
        are discounted.
      fixed_leg_dates: A `DateTensor` of shape `[batch_shape, n]` representing
        the cashflow dates of the fixed leg including the `calc_date` as the
        first entry for each swap in the batch.
      float_leg_dates: A `DateTensor` of shape `[batch_shape, n]` representing
        the cashflow dates of the float leg including the `calc_date` as the
        first entry for each swap in the batch. 
      fixed_leg_rates: A real `Tensor` of shape broadcastable with
        `[batch_shape, n]` representing the fixed rates of the swap.
      float_leg_rates: A real `Tensor` of shape broadcastable with
        `[batch_shape, n]` and of the same dtype as `fixed_leg_rates`.
        Represents the float rates of the swap.
      notional: A real `Tensor` of shape broadcastable with `[batch_shape, n]`
        and of the same dtype as `fixed_leg_rates`. Represents the notional of
        the swap.
      day_count: A daycount convention. One of `dates.daycounts`.
      discount_fn: A callable that maps `DateTensor` to a real number of the
        same `dtype` as `fixed_leg_rates` which corresponds to the discounting
        function.
      dtype: A `dtype` for the underlying real `Tensor`s.
        Default value: None which maps to the `dtype` inferred by TensorFlow.
    """  
    self._calc_date = calc_date
    self._fixed_leg_dates = fixed_leg_dates
    self._float_leg_dates = float_leg_dates
    self._fixed_leg_rates = fixed_leg_rates
    self._float_leg_rates = float_leg_rates
    self._notional = notional
    self._day_count = day_count
    self._discount_fn = discount_fn
    self._dtype = dtype

  def fixed_cashflows(self):
    """Returns all fixed cashflows at `fixed_leg_dates`."""
    start_dates = self._fixed_leg_dates[:-1]
    end_dates = self._fixed_leg_dates[1:]
    t = self._day_count(
        start_date=start_dates, end_date=end_dates, dtype=self._dtype)
    t = tf.where(t > 0, t, 0)
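    # `fixed_leg_rates` may be a Python scalar without a static shape, so the
    # float leg rates serve as the shape reference here.
    if t.shape.as_list() != self._float_leg_rates.shape.as_list():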
      t = tf.expand_dims(t, axis=-1)
    return self._notional * self._fixed_leg_rates * t

  def float_cashflows(self):
    """Returns all float cashflows at `float_leg_dates`."""
    start_dates = self._float_leg_dates[:-1]
    end_dates = self._float_leg_dates[1:]
    t = self._day_count(
        start_date=start_dates, end_date=end_dates, dtype=self._dtype)
    t = tf.where(t > 0, t, 0)
    if t.shape.as_list() != self._float_leg_rates.shape.as_list():
      t = tf.expand_dims(t, axis=-1)
    return self._notional * self._float_leg_rates * t

  def float_leg_present_value(self):
    """Returns the value of the float leg discounted to `self.calc_date`."""
    payment_dates = self._float_leg_dates[1:]
    cashflows = self.float_cashflows()
    return tf.reduce_sum(cashflows * self._discount_fn(payment_dates),
                         axis=0)

  def fixed_leg_present_value(self):
    """Returns the value of the fixed leg discounted to `self.calc_date`."""
    payment_dates = self._fixed_leg_dates[1:]
    cashflows = self.fixed_cashflows()
    return tf.reduce_sum(cashflows * self._discount_fn(payment_dates),
                         axis=0)

  def price(self):
    """Returns the value of the swap discounted to `self.calc_date`."""
    return self.float_leg_present_value() - self.fixed_leg_present_value()
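
To make the relation between `get_fixings` and the leg values concrete, here is a minimal plain-Python sketch for one swap on one curve. The helper names and the discount factors are illustrative assumptions, not part of TFF. It shows that when the float rates are the fixings implied by the discount curve, the float leg value telescopes to `notional * (P_0 - P_n)`.

```python
def implied_fixings(discounts, accruals):
    """Simple forward rates (P_i / P_{i+1} - 1) / t_i; zero where t_i is ~0."""
    return [(discounts[i] / discounts[i + 1] - 1.0) / t if t > 1e-8 else 0.0
            for i, t in enumerate(accruals)]

def leg_present_value(notional, rates, accruals, discounts):
    """Sum of notional * rate * accrual, discounted at each period's end."""
    return sum(notional * r * t * d
               for r, t, d in zip(rates, accruals, discounts[1:]))

# Hypothetical discount factors at the schedule dates (first entry: start).
discounts = [1.0, 0.99, 0.975, 0.955]
accruals = [0.5, 0.5, 0.5]
notional = 1_000_000.0

fixings = implied_fixings(discounts, accruals)
float_pv = leg_present_value(notional, fixings, accruals, discounts)
fixed_pv = leg_present_value(notional, [0.0039] * 3, accruals, discounts)
swap_pv = float_pv - fixed_pv  # same sign convention as VanillaSwap.price()
print(float_pv, fixed_pv, swap_pv)
```

Here the float leg is worth `1_000_000 * (1.0 - 0.955)`, independent of the intermediate discount factors.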


In [4]:
#@title Swap schedule generation and Hull-White model parameters

from collections import namedtuple

YieldParams = namedtuple('YieldParams', ['a0', 'a1', 'a2'])

def random_yield_params(size, max_time=30.0):
  r0 = dtype(np.random.rand(size)* (0.07 - 0.005) + 0.005)
  rT = dtype(np.random.rand(size)* (0.07 - 0.005) + 0.005)
  r_mins = np.minimum(r0, rT)
  r_maxs = np.maximum(r0, rT)
  do_low = np.random.rand(size) > dtype(0.5)
  a0 = np.random.rand(size)
  a0 = np.where(do_low, a0 * (r_mins - 0.0001) + 0.0001, a0 * (0.08 - r_maxs) + r_maxs)
  a2 = max_time / (1 + np.random.choice([-1.0, 1.0], size=size) * np.sqrt((rT - a0)/(r0-a0)))
  a1 = (r0 - a0) / a2 / a2
  return YieldParams(a0=a0,a1=a1,a2=a2)

def log_current_discount_fwd_fn(yield_params):
  """Suitable for the next log discount evaluator below."""
  a0 = np.array(yield_params.a0, dtype=dtype)
  a1 = np.array(yield_params.a1, dtype=dtype)
  a2 = np.array(yield_params.a2, dtype=dtype)
  def eval_fn(times):
    """Gives the log zero coupon bond price and the instantaneous forward rate."""
    log_discount = -(a0 + a1 * (times - a2)**2) * times
    inst_forward = (a0 - a1 * a2 * a2 / 3) + 3 * a1 * (times - 2 * a2 / 3)**2
    return log_discount, inst_forward
  return eval_fn


HullWhiteData = namedtuple('HullWhiteData',
                           ['mean_reversion', 'volatility',
                            'log_discount_fwd_fn'])


def gen_hull_white_params():
  mean_reversion = np.random.rand() * 0.1
  volatility = np.random.rand() * 0.3
  present_yield_curve_params = random_yield_params(1)
  return HullWhiteData(
      mean_reversion=mean_reversion,
      volatility=volatility,
      log_discount_fwd_fn=log_current_discount_fwd_fn(present_yield_curve_params))
  
def generate_short_rates(initial_short_rates,
                         hull_white_params, times, num_scenarios,
                         dtype):
  a = dtype(hull_white_params.mean_reversion)
  sigma = dtype(hull_white_params.volatility)
  def instant_forward_rate_fn(t):
    return hull_white_params.log_discount_fwd_fn(t)[1]
  process = tff.models.hull_white.HullWhiteModel1F(
      mean_reversion=a, volatility=sigma,
      instant_forward_rate_fn=instant_forward_rate_fn,
      dtype=dtype)
  sample_paths = process.sample_paths
  paths = sample_paths(
      times,
      num_samples=num_scenarios,
      initial_state=initial_short_rates,
      seed=42)
  paths = tf.squeeze(paths, axis=-1)
  # Shape [num_times, num_samples]
  return tf.transpose(paths)

def discount_curve_at_times(
    calc_date,
    short_rates,
    start_date,
    end_date,
    current_disc_fwd_fn,
    day_count,
    mean_revs,
    sigmas):
  """Computes forward discount factors.

  Produces P(S, T), i.e. the discount factor as seen at the calculation time
  'S' for maturity at the evaluation time 'T'.

  Time today is 0. The evaluation times are allowed to be negative, but the
  calculation times are not.

  Args:
    calc_date: The date corresponding to time 0 (today).
    short_rates: The short rates of shape [N_scenarios].
    start_date: The calculation date 'S'.
    end_date: The evaluation dates 'T'.
    current_disc_fwd_fn: A callable that returns the log-discount and the
      instantaneous forward rate (in that order) at the specified times.
    day_count: A daycounting function.
    mean_revs: A tensor of shape [num_currencies]. The mean reversions for
      the HW model.
    sigmas: A tensor of the same shape as mean_revs. The volatilities for
      the HW model.

  Returns: A tensor of shape [end_date.shape] + [N_scenarios]
  """
  # T - S
  day_fractions = day_count(
      start_date=start_date,
      end_date=end_date, dtype=dtype)
  S = day_count(
      start_date=calc_date,
      end_date=start_date, dtype=dtype)
  T = day_count(
      start_date=calc_date,
      end_date=end_date, dtype=dtype)

  b_exp = mean_revs * day_fractions  # shape [N_calc_dates, n_eval_dates]
  b = (1 - tf.exp(-b_exp)) / mean_revs
  lnP_p = current_disc_fwd_fn(T)[0]  # shape [n_eval_dates]
  lnP_den, inst_fwd = current_disc_fwd_fn(S)  # output of shapes [n_calc_dates, n_eval_dates]
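  # Variance term sigma^2 / (4 * alpha) * B^2 * (1 - exp(-2 * alpha * S)),
  # matching the expression for ln A(S, T) given earlier.
  lnA = (lnP_p - lnP_den + b * inst_fwd
         - (sigmas * b) ** 2 / (4 * mean_revs)
         * (1 - tf.exp(-2 * mean_revs * S)))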
  lnA = tf.expand_dims(lnA, axis=-1)
  b = tf.expand_dims(b, axis=-1)
  short_rates = tf.expand_dims(short_rates, axis=0)
  discounts = tf.exp(lnA - b * short_rates)
  # Adjust for when calc_date > eval_date
  return tf.where(tf.expand_dims(T > S, -1), discounts, 1.0)
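
The present-day curve used above is a quadratic zero rate, $y(t) = a_0 + a_1 (t - a_2)^2$, with $\ln P(0, t) = -y(t)\, t$. The sketch below repeats the algebra of `random_yield_params` and `eval_fn` in plain Python with the random draws replaced by fixed, illustrative values (`fit_params` and its arguments are hypothetical names). It lets one check that the curve hits the two end-point rates and that the closed-form forward rate equals $-\partial_t \ln P(0, t)$.

```python
import math

def log_discount_and_forward(t, a0, a1, a2):
    """Mirrors `eval_fn`: ln P(0, t) and the instantaneous forward f(0, t)."""
    log_discount = -(a0 + a1 * (t - a2) ** 2) * t
    forward = (a0 - a1 * a2 * a2 / 3) + 3 * a1 * (t - 2 * a2 / 3) ** 2
    return log_discount, forward

def fit_params(r0, r_end, a0, max_time=30.0, sign=1.0):
    """Same algebra as `random_yield_params`, with the random draws fixed."""
    a2 = max_time / (1 + sign * math.sqrt((r_end - a0) / (r0 - a0)))
    a1 = (r0 - a0) / a2 / a2
    return a0, a1, a2

a0, a1, a2 = fit_params(r0=0.02, r_end=0.04, a0=0.01)

def zero_rate(t):
    """Zero rate y(t) implied by the three parameters."""
    return a0 + a1 * (t - a2) ** 2

print(zero_rate(0.0), zero_rate(30.0))
print(log_discount_and_forward(10.0, a0, a1, a2))
```

By construction $y(0) = r_0$ and $y(\text{max\_time}) = r_T$, so `a0` acts as the extremum of the zero curve and `a2` as its location.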


# Pricing vanilla swaps comparison: CPU vs GPU vs TPU

Note that the **QuantLib** pricing speed for a swap with 40 payments is
**10000 swaps / sec** on an Intel CPU. A TPU can price **2 million / sec**.

In [5]:
tf.compat.v1.reset_default_graph()
hull_white_params = gen_hull_white_params()
calc_date = dates.from_year_month_day(2015, 9, 9)
# Corresponding ql.WeekendsOnly calendar
calendar = dates.HolidayCalendar2(
    weekend_mask=dates.WeekendMask.SATURDAY_SUNDAY)
# Business day convention
bussiness_convention = dates.BusinessDayConvention.FOLLOWING
day_count = dates.daycounts.actual_360
settlement_days = 2
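
For reference, the Actual/360 convention selected above divides the number of calendar days between two dates by 360. A minimal standard-library sketch (the function below is illustrative and is not the TFF implementation):

```python
from datetime import date

def actual_360(start, end):
    """Year fraction under Actual/360: calendar days between dates over 360."""
    return (end - start).days / 360.0

calc_date = date(2015, 9, 9)
one_year_later = date(2016, 9, 9)
print(actual_360(calc_date, one_year_later))
```

The period from 2015-09-09 to 2016-09-09 contains 29 February 2016, so it spans 366 calendar days and the year fraction is slightly above 1.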

In [6]:
# Generate swaps
TENORS = [1, 3, 6, 12] # months

NUM_SWAPS = 100 #@param
NUM_CALCULATION_DATES = 136 #@param

NUM_SWAPS_PER_TENOR = NUM_SWAPS // len(TENORS)
# Start date of the swaps
start_dates = []
end_dates = []
schedule_dates = []
long_short = []

for t in TENORS:
  random_shift = np.int32(t * 2 *
                          (np.random.rand(NUM_SWAPS_PER_TENOR) - 0.5) * 30)
  start_date = calendar.add_business_days(
      calc_date, random_shift, roll_convention=bussiness_convention)
  # Maximum swap tenure is 30 years
  swap_tenure = 1 + np.int32(30 * (np.random.rand(NUM_SWAPS_PER_TENOR)))
  period = dates.periods.PeriodTensor(swap_tenure, dates.PeriodType.YEAR)
  end_date = calendar.add_period_and_roll(start_date, period,
                                          bussiness_convention)
  # Exchange schedules
  schedule = dates.PeriodicSchedule(
      start_date=start_date, end_date=end_date,
      tenor=dates.periods.months(t),
      holiday_calendar=calendar,
      roll_convention=bussiness_convention)
  schedule_date = schedule.dates()
  # Record swap data
  start_dates.append(start_date)
  end_dates.append(end_date)
  schedule_dates.append(schedule_date.transpose())
  long_short.append(2 * (np.random.binomial(1, 0.52, size=random_shift.shape) - 0.5))
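
The schedule logic above can be illustrated without TFF. The sketch below uses only the standard library and is an assumption-laden simplification: it generates dates forward from the start date, treats only Saturdays and Sundays as holidays, and applies a FOLLOWING roll. It is not the algorithm used by `dates.PeriodicSchedule`, and all function names are hypothetical.

```python
import calendar
from datetime import date, timedelta

def add_months(d, months):
    """Shifts a date by whole months, clamping the day to the month's end."""
    index = d.year * 12 + (d.month - 1) + months
    year, month = divmod(index, 12)
    last_day = calendar.monthrange(year, month + 1)[1]
    return date(year, month + 1, min(d.day, last_day))

def roll_following(d):
    """Moves Saturdays and Sundays forward to the next Monday."""
    while d.weekday() >= 5:
        d += timedelta(days=1)
    return d

def periodic_schedule(start, end, tenor_months):
    """Unadjusted dates start, start + tenor, ... up to end, each rolled."""
    dates_out, k = [], 0
    while True:
        unadjusted = add_months(start, k * tenor_months)
        if unadjusted > end:
            break
        dates_out.append(roll_following(unadjusted))
        k += 1
    return dates_out

schedule = periodic_schedule(date(2015, 9, 9), date(2016, 9, 9), 3)
print(schedule)
```

A one-year swap with a 3-month tenor yields five schedule dates: the start date plus four payment dates.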

In [7]:
def discount_curve(dates):
  day_fractions = day_count(
      start_date=calc_date,
      end_date=dates)
  def discount_curve(t):
    return tf.exp(hull_white_params.log_discount_fwd_fn(t)[0])
  discounts = discount_curve(day_fractions)
  return tf.where(dates > calc_date, discounts, 1.0)



In [8]:
swaps = []

for schedule in schedule_dates:
  # Compute the float leg rates
  float_leg_rates = get_fixings(
          dates_tensor=schedule,
          day_count=day_count,
          discount_fn=discount_curve,
          dtype=dtype)

  swap = VanillaSwap(
      calc_date=calc_date,
      fixed_leg_dates=schedule,
      float_leg_dates=schedule,
      fixed_leg_rates=0.0039,
      float_leg_rates=float_leg_rates,
      notional=1000000,
      day_count=day_count,
      discount_fn=discount_curve,
      dtype=dtype)
  swaps.append(swap)

In [9]:
def swap_price():
  return [position * swap.price()
          for swap, position in zip(swaps, long_short)]

### Speed comparison across platforms

In [10]:
tpu_price = tf.compat.v1.tpu.rewrite(swap_price)

In [11]:
# Prepare session (resolves cluster settings)
internal_ip = "10.141.97.250"
sess = tf.compat.v1.Session("grpc://{}:8470".format(internal_ip))
sess.run(tf.compat.v1.tpu.initialize_system())

b'\n\x03\x02\x02\x02\x10\x01\x18\x08"\x18\x00\x00\x00\x00\x00\x01\x00\x01\x00\x00\x01\x01\x01\x00\x00\x01\x00\x01\x01\x01\x00\x01\x01\x01'

In [12]:
%%timeit 
#@title TPU pricing speed
# This is for a single core. Note that TPU has 8 cores.
sess.run(tpu_price) 

5.45 ms ± 380 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
#with tf.device("/CPU:0"):
#  cpu_price = tf.xla.experimental.compile(swap_price)
#  cpu_price2 = xla.compile(swap_price)

#sess = tf.compat.v1.Session()

In [14]:
%%timeit
#@title CPU pricing speed (112 virtual CPUs, see below)
sess.run(cpu_price) 

NameError: name 'cpu_price' is not defined

In [None]:
#@title Tesla V100 GPU pricing speed
with tf.device("/gpu:0"):
  gpu_price = tf.xla.experimental.compile(swap_price)
sess = tf.compat.v1.Session()


In [None]:
#@title GPU pricing speed
#%%timeit
#sess.run(gpu_price) #  4.11 s

## CPU info

In [15]:
!cat /proc/cpuinfo

processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 63
model name	: Intel(R) Xeon(R) CPU @ 2.30GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2300.000
cache size	: 46080 KB
physical id	: 0
siblings	: 4
core id		: 0
cpu cores	: 2
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single ssbd ibrs ibpb stibp kaiser fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs itlb_multihit
bogomips	: 4600.00
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

processor	:

# Sample 1 scenario for 1 million swaps at 136 time points on a TPU

A standard TPU with 8 cores can solve this problem in **3 seconds**.

Note that, similarly, one can get access to a TPU with **512 cores**, in which case
the scenarios can be generated in **45 ms**.
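
As a reference for what "sampling a scenario" means here, the following plain-Python sketch simulates one Hull-White short-rate path with an Euler scheme. It is only an illustration under stated assumptions: a flat initial forward curve $f(0, t) = f$, for which the drift target is $\theta(t) = \alpha f + \frac{\sigma^2}{2\alpha}(1 - e^{-2\alpha t})$, and an Euler discretization that is not necessarily the scheme used by `HullWhiteModel1F.sample_paths`. All names and parameter values are hypothetical.

```python
import math
import random

def simulate_short_rate(r0, alpha, sigma, flat_forward, times, seed=42):
    """Euler scheme for dr = (theta(t) - alpha * r) dt + sigma dW.

    Assumes a flat initial forward curve f(0, t) = flat_forward, for which
    theta(t) = alpha * f + sigma^2 / (2 * alpha) * (1 - exp(-2 * alpha * t)).
    """
    rng = random.Random(seed)
    path, r, t_prev = [], r0, 0.0
    for t in times:
        dt = t - t_prev
        theta = (alpha * flat_forward
                 + sigma ** 2 / (2.0 * alpha)
                 * (1.0 - math.exp(-2.0 * alpha * t_prev)))
        r += (theta - alpha * r) * dt + sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        path.append(r)
        t_prev = t
    return path

# 136 daily steps measured in Actual/360 year fractions.
times = [i / 360.0 for i in range(1, 137)]
path = simulate_short_rate(0.02, alpha=0.05, sigma=0.01,
                           flat_forward=0.02, times=times)
print(len(path), path[-1])
```

With `sigma=0` and `r0` equal to the flat forward rate, the drift vanishes and the path stays constant, which is a convenient check of the drift term.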

In [22]:
tf.compat.v1.reset_default_graph()
calc_date = dates.from_year_month_day(2015, 9, 9)
# Corresponding ql.WeekendsOnly calendar
calendar = dates.HolidayCalendar2(
    weekend_mask=dates.WeekendMask.SATURDAY_SUNDAY)
# Business day convention
bussiness_convention = dates.BusinessDayConvention.FOLLOWING
day_count = dates.daycounts.actual_360
settlement_days = 2

TENORS = [1, 3, 6, 12] # months

NUM_SWAPS = 1000 #@param
NUM_CALCULATION_DATES = 136 #@param
NUM_SAMPLES =  100000 #@param
NUM_TPU_CORES = 8 #@param

NUM_SWAPS_PER_TENOR = NUM_SWAPS // len(TENORS)

def generate_swap_schedules(num_swaps_per_tenor):
  # Start date of the swaps
  start_dates = []
  end_dates = []
  schedule_dates = []
  long_short = []

  for t in TENORS:
    random_shift = np.int32(t * 2 *
                            (np.random.rand(num_swaps_per_tenor) - 0.5) * 30)
    start_date = calendar.add_business_days(
        calc_date, random_shift, roll_convention=bussiness_convention)
    # Maximum swap tenure is 30 years
    swap_tenure = 1 + np.int32(30 * (np.random.rand(num_swaps_per_tenor)))
    period = dates.periods.PeriodTensor(swap_tenure, dates.PeriodType.YEAR)
    end_date = calendar.add_period_and_roll(start_date, period,
                                            bussiness_convention)
    # Exchange schedules
    schedule = dates.PeriodicSchedule(
        start_date=start_date, end_date=end_date,
        tenor=dates.periods.months(t),
        holiday_calendar=calendar,
        roll_convention=bussiness_convention)
    schedule_date = schedule.dates()
    # Record swap data
    start_dates.append(start_date)
    end_dates.append(end_date)
    schedule_dates.append(schedule_date.transpose())
    long_short.append(2 * (np.random.binomial(1, 0.501, size=random_shift.shape) - 0.5))
  return schedule_dates, long_short

In [23]:

schedule_dates = []
long_short = []

# We split swaps into 8 date tensors in order to distribute the computations
for _ in range(NUM_TPU_CORES):
  schedules, positions = generate_swap_schedules(
      NUM_SWAPS_PER_TENOR // NUM_TPU_CORES)
  schedule_dates.append(schedules)
  long_short.append(positions)

# All the calculation dates at which to price the swaps
calc_times = calendar.add_business_days(calc_date, range(NUM_CALCULATION_DATES),
                                        roll_convention=bussiness_convention)
hull_white_params = gen_hull_white_params()
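
Because the split across tenors and cores uses integer division, the number of swaps actually generated can be smaller than `NUM_SWAPS`. The short sketch below reproduces that arithmetic; the helper function is illustrative and not part of the notebook.

```python
TENORS = [1, 3, 6, 12]  # months, as in the notebook

def swaps_actually_generated(num_swaps, num_cores):
    """Counts swaps after the two integer divisions used in the notebook."""
    per_tenor = num_swaps // len(TENORS)
    per_tenor_per_core = per_tenor // num_cores
    return per_tenor_per_core * len(TENORS) * num_cores

print(swaps_actually_generated(1000, 8))
print(swaps_actually_generated(1_000_000, 8))
```

With `NUM_SWAPS = 1000` and 8 cores, each core receives 31 swaps per tenor, for 992 swaps in total; with 1 million swaps the division is exact.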



In [24]:
# Since storing 1 million swaps takes a lot of memory, we
# compute aggregated values for each date, i.e.,
# max(sum(swap_prices_at_date_i), 0)
def aggregated_prices_for_schedule(schedule_dates, long_short):
  def aggregated_prices_at_time(iter_num, short_rates):
    def discount_curve(dates):
      return discount_curve_at_times(
          calc_date,
          short_rates,
          calc_times[iter_num],
          dates,
          hull_white_params.log_discount_fwd_fn,
          day_count,
          hull_white_params.mean_reversion,
          hull_white_params.volatility)
    swaps = []
    for schedule in schedule_dates:
      # Compute the float leg rates
      float_leg_rates = get_fixings(
              dates_tensor=schedule,
              day_count=day_count,
              discount_fn=discount_curve,
              dtype=dtype)

      swap = VanillaSwap(
          calc_date=calc_times[iter_num],
          fixed_leg_dates=schedule,
          float_leg_dates=schedule,
          fixed_leg_rates=0.0039,
          float_leg_rates=float_leg_rates,
          notional=1000000,
          day_count=day_count,
          discount_fn=discount_curve,
          dtype=dtype)
      swaps.append(swap)
    prices = [np.expand_dims(position, axis=-1) * swap.price()
              for swap, position in zip(swaps, long_short)]
    return tf.reduce_sum(prices, axis=0)
  return aggregated_prices_at_time

In [25]:
# Distributed calculations on a TPU
aggregated_prices = tf.TensorArray(dtype, size=NUM_CALCULATION_DATES)

short_rates = generate_short_rates(
    0.02, hull_white_params,
    day_count(start_date=calc_date,
              end_date=calc_times, dtype=dtype),
    NUM_SAMPLES,
    dtype)


def cond_fn(iter_num, aggregated_prices):
  return iter_num < NUM_CALCULATION_DATES 

def body_fn(iter_num, aggregated_prices):
  all_replications = []
  for i in range(NUM_TPU_CORES):
    with tf.device("/job:tpu_worker/replica:0/task:0/device:TPU:{}".format(i)):
      all_replications.append(
          tf.compat.v1.tpu.rewrite(
              aggregated_prices_for_schedule(schedule_dates[i],
                                            long_short[i]),
          inputs=[iter_num, short_rates[iter_num]]))
  aggregated_replicas = 0
  for replica in all_replications:
    aggregated_replicas += replica[0]
  aggregated_replicas = tf.reduce_sum(aggregated_replicas, axis=0)
  aggregated_replicas = tf.maximum(aggregated_replicas, 0)
  aggregated_replicas = tf.reduce_mean(aggregated_replicas)
  aggregated_prices = aggregated_prices.write(iter_num, aggregated_replicas)
  return iter_num + 1, aggregated_prices

_, aggregated_prices = tf.while_loop(cond_fn, body_fn, (0, aggregated_prices))

summary = aggregated_prices.stack()
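
The reduction performed in `body_fn` for each calculation date is: net the signed swap values per scenario, floor the netted value at zero, then average over scenarios. A minimal plain-Python sketch of that reduction on made-up numbers (the function name and inputs are illustrative):

```python
def expected_positive_exposure(prices_by_swap):
    """prices_by_swap[i][j]: signed value of swap i in scenario j."""
    num_scenarios = len(prices_by_swap[0])
    # Net all swaps within each scenario.
    netted = [sum(swap[j] for swap in prices_by_swap)
              for j in range(num_scenarios)]
    # Floor at zero per scenario, then average across scenarios.
    return sum(max(v, 0.0) for v in netted) / num_scenarios

prices = [[10.0, -5.0, 2.0],
          [-4.0, -3.0, 1.0]]
epe = expected_positive_exposure(prices)
print(epe)
```

The netted values are 6, -8 and 3, so the positive parts average to 3. Flooring after netting (rather than per swap) is what keeps the memory footprint to one number per date.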

In [26]:
# Prepare session (resolves cluster settings)
internal_ip = "10.141.97.250"
sess = tf.compat.v1.Session("grpc://{}:8470".format(internal_ip))
sess.run(tf.compat.v1.tpu.initialize_system())

b'\n\x03\x02\x02\x02\x10\x01\x18\x08"\x18\x00\x00\x00\x00\x00\x01\x00\x01\x00\x00\x01\x01\x01\x00\x00\x01\x00\x01\x01\x01\x00\x01\x01\x01'

In [28]:
#@title TPU performance (after automatic optimizations)
t = time.time()
res_tpu = sess.run(summary)
print("wall time: ", time.time() - t)

wall time:  144.42943406105042
