In [1]:
import tensorflow as tf
import numpy as np


In [2]:
def ex1(tensor: tf.Tensor) -> tf.Tensor:
    assert tensor.dtype == 'float32', "wrong dtype"
    mask = tensor > 0
    boolean_mask = tf.boolean_mask(tensor, mask)
    transposed = tf.transpose(boolean_mask)
    multiplied = tf.matmul(
        boolean_mask[:, tf.newaxis], transposed[tf.newaxis, :])
    expanded = tf.expand_dims(multiplied, axis=0)
    repeat = tf.repeat(expanded, 5, axis=0)
    return repeat


In [3]:
t1 = tf.random.uniform([3], minval=-10, maxval=10)
t2 = tf.random.uniform([33, 7], minval=-10, maxval=10)
t3 = tf.random.uniform([2, 32, 6], minval=-10, maxval=10)

print(ex1(t1))
print(ex1(t2))
print(ex1(t3))


tf.Tensor(
[[[25.129848]]

 [[25.129848]]

 [[25.129848]]

 [[25.129848]]

 [[25.129848]]], shape=(5, 1, 1), dtype=float32)
tf.Tensor(
[[[65.47646  27.262085 36.15457  ... 57.431866 62.246357 57.635128]
  [27.262085 11.350967 15.053485 ... 23.912596 25.917181 23.997227]
  [36.15457  15.053485 19.963709 ... 31.712532 34.370983 31.824768]
  ...
  [57.431866 23.912596 31.712532 ... 50.37564  54.598614 50.553932]
  [62.246357 25.917181 34.370983 ... 54.598614 59.1756   54.79185 ]
  [57.635128 23.997227 31.824768 ... 50.553932 54.79185  50.732853]]

 [[65.47646  27.262085 36.15457  ... 57.431866 62.246357 57.635128]
  [27.262085 11.350967 15.053485 ... 23.912596 25.917181 23.997227]
  [36.15457  15.053485 19.963709 ... 31.712532 34.370983 31.824768]
  ...
  [57.431866 23.912596 31.712532 ... 50.37564  54.598614 50.553932]
  [62.246357 25.917181 34.370983 ... 54.598614 59.1756   54.79185 ]
  [57.635128 23.997227 31.824768 ... 50.553932 54.79185  50.732853]]

 [[65.47646  27.262085 36.15457  

In [4]:
def ex2(tensor: tf.Tensor) -> tf.Tensor:
    assert tensor.dtype == 'float32', "wrong dtype"
    print(tensor)
    sliced = tf.where(tf.reduce_min(tensor) > 0,
                      tensor[:, 0], tensor[:, 1]) / 8
    rounded = tf.round(sliced)
    cast = tf.cast(rounded, tf.int32)
    return cast


In [5]:
t4 = tf.constant([[5, 6], [7, 3], [4, 5]], dtype=tf.float32)
t5 = tf.constant([[-4, 5, 6], [7, 8, 9]], dtype=tf.float32)

print(ex2(t4))
print(ex2(t5))


tf.Tensor(
[[5. 6.]
 [7. 3.]
 [4. 5.]], shape=(3, 2), dtype=float32)
tf.Tensor([1 1 0], shape=(3,), dtype=int32)
tf.Tensor(
[[-4.  5.  6.]
 [ 7.  8.  9.]], shape=(2, 3), dtype=float32)
tf.Tensor([1 1], shape=(2,), dtype=int32)


In [9]:
def ex3(tensor1: tf.Tensor, tensor2: tf.Tensor) -> tf.Tensor:
    assert (tf.shape(tensor1).numpy() == tf.shape(
        tensor2).numpy()).all(), "incompatible shape"
    indices = tf.where(tf.logical_and(tensor1 < 4, tensor1 > 0))
    updated = tf.tensor_scatter_nd_update(tensor2, indices, [7, 7, 7])
    return updated


In [10]:
t6 = tf.constant([[8, 9, 0], [2, 3, 4], [6, 1, 8]], dtype='float32')
t7 = tf.constant([[4, 5, 8], [23, 32, 14], [6, 1, 15]], dtype='float32')

print(ex3(t6, t7))


tf.Tensor(
[[ 4.  5.  8.]
 [ 7.  7. 14.]
 [ 6.  7. 15.]], shape=(3, 3), dtype=float32)
