# How to update the CTA policy in a distributed setting

The original code does the following:
```python
    def update_cta_rates(self):
        """Update the CTA rates from the EMA model's predictions on the probe batch.

        For every probe sample, the L1 distance between the softmax output of
        ``self.ema_model`` and the one-hot target is computed, and
        ``1.0 - 0.5 * error`` is passed to ``self.cta.update_rates`` together
        with the policy that was applied to that sample.
        """
        x, y, policies = self.state.batch["cta_probe_batch"]
        self.ema_model.eval()
        with torch.no_grad():
            y_pred = self.ema_model(x)
            y_probas = torch.softmax(y_pred, dim=1)  # (N, C)

            for y_proba, t, policy in zip(y_probas, y, policies):
                # `error` aliases the row of `y_probas`, so the subtraction
                # below modifies `y_probas` in place.
                error = y_proba
                # Subtract the one-hot target: only the true-class entry changes.
                error[t] -= 1
                # L1 distance to the one-hot target. As a softmax row sums to 1,
                # this equals 2 * (1 - p_true), i.e. it lies in [0, 2].
                error = torch.abs(error).sum()
                # Hence 1 - 0.5 * error lies in [0, 1] and equals the probability
                # the model assigns to the true class.
                self.cta.update_rates(policy, 1.0 - 0.5 * error.item())
```

In [1]:
import sys
sys.path.insert(0, "..")

In [2]:
import utils

# CTA object whose rates are to be updated (its `rates` attribute is read
# further below).
cta = utils.get_default_cta()

# Supervised training set read from /tmp/cifar10/ — presumably a labelled
# CIFAR-10 subset; TODO confirm what the `0_250` suffix selects.
supervised_train_dataset = utils.get_supervised_trainset_0_250("/tmp/cifar10/")

# Loader of CTA probe batches. Each batch is indexable by 'policy', which holds
# one serialized policy per sample (deserialized further below).
cta_probe_loader = utils.get_cta_probe_loader(
    supervised_train_dataset,
    cta=cta,
    batch_size=8,
    num_workers=12,
    sampler=None
)

Let's say we have WORLD_SIZE=2, so that each of the 2 ranks gets its own probe batch. How do we need to update the CTA rates?

In [3]:
# Emulate WORLD_SIZE=2 in a single process: take two consecutive probe batches
# and treat them as the batches seen by rank 1 and rank 2.
cta_probe_loader_iter = iter(cta_probe_loader)
cta_probe_batch_r1 = next(cta_probe_loader_iter)
cta_probe_batch_r2 = next(cta_probe_loader_iter)

In [4]:
# "Rank 1" probe batch: the two values returned by utils.sup_prepare_batch
# (inputs and targets) plus one deserialized policy per sample.
(x1, y1), policies1 = (
    utils.sup_prepare_batch(cta_probe_batch_r1, utils.device, non_blocking=True),
    list(map(utils.deserialize, cta_probe_batch_r1['policy'])),
)

In [5]:
# "Rank 2" probe batch: the two values returned by utils.sup_prepare_batch
# (inputs and targets) plus one deserialized policy per sample.
(x2, y2), policies2 = (
    utils.sup_prepare_batch(cta_probe_batch_r2, utils.device, non_blocking=True),
    list(map(utils.deserialize, cta_probe_batch_r2['policy'])),
)

In [7]:
# Sanity check: one policy per sample, so both lengths should equal batch_size (8).
len(policies1), len(policies2)

(8, 8)

In [8]:
# Inspect the "rank 1" policies: each policy is a list of OP(f=<op name>, bins=[...]).
policies1

[[OP(f='translate_y', bins=[0.6284187969367999]),
  OP(f='smooth', bins=[0.9300541326018598])],
 [OP(f='rescale', bins=[0.4410249998926222, 0.37404114665565613]),
  OP(f='rescale', bins=[0.34877247031343983, 0.80745831615098])],
 [OP(f='autocontrast', bins=[0.9504645469663052]),
  OP(f='brightness', bins=[0.5678519516876261])],
 [OP(f='invert', bins=[0.8041796089170575]), OP(f='identity', bins=[])],
 [OP(f='equalize', bins=[0.0017531965395323201]),
  OP(f='translate_y', bins=[0.8608303540742768])],
 [OP(f='blur', bins=[0.9104848104507769]),
  OP(f='posterize', bins=[0.31135027907028645])],
 [OP(f='shear_y', bins=[0.6413432131765835]),
  OP(f='translate_x', bins=[0.6387265618124591])],
 [OP(f='translate_x', bins=[0.384357732892316]),
  OP(f='invert', bins=[0.2567598188184105])]]

In [9]:
# Inspect the "rank 2" policies: each policy is a list of OP(f=<op name>, bins=[...]).
policies2

[[OP(f='autocontrast', bins=[0.6284187969367999]),
  OP(f='brightness', bins=[0.9300541326018598])],
 [OP(f='shear_y', bins=[0.4410249998926222]),
  OP(f='autocontrast', bins=[0.37404114665565613])],
 [OP(f='solarize', bins=[0.34877247031343983]),
  OP(f='solarize', bins=[0.80745831615098])],
 [OP(f='shear_x', bins=[0.9504645469663052]),
  OP(f='rescale', bins=[0.5678519516876261, 0.8041796089170575])],
 [OP(f='invert', bins=[0.0017531965395323201]),
  OP(f='shear_x', bins=[0.8608303540742768])],
 [OP(f='shear_y', bins=[0.9104848104507769]),
  OP(f='solarize', bins=[0.31135027907028645])],
 [OP(f='translate_x', bins=[0.6413432131765835]),
  OP(f='brightness', bins=[0.6387265618124591])],
 [OP(f='brightness', bins=[0.384357732892316]),
  OP(f='solarize', bins=[0.2567598188184105])]]

On each rank, let's store `error_per_op` as a stack of fixed-size packed rows `(op name index, num_bins, bins, error, [PAD], ..., [PAD])`, one row per op applied in the probe batch. Then gather the `error_per_op` tensors of all ranks into a list.

In [89]:
import torch
from ctaugment import OPS


# Alphabetically ordered op names; an op's position in this list is the integer
# id written into / read from the packed tensors.
sorted_op_names = sorted(OPS)


def pack_as_tensor(k, bins, error, size=5, pad_value=-555.0):
    """Pack one applied op and its error into a fixed-size 1-D tensor.

    Layout of the result::

        [op index, num_bins, bin_0, ..., bin_{num_bins-1}, error, pad, ..., pad]

    where ``op index`` is the position of ``k`` in ``sorted_op_names``.

    Args:
        k: op name; must be an element of ``sorted_op_names``.
        bins: sequence of numbers attached to the op (may be empty).
        error: tensor holding a single value; the result is created with its
            dtype and on its device.
        size: total length of the packed tensor.
        pad_value: filler written to the unused trailing slots.

    Returns:
        A 1-D tensor of length ``size``.

    Raises:
        ValueError: if ``k`` is not in ``sorted_op_names`` (raised by
            ``list.index``) or if the record does not fit into ``size`` slots.
    """
    num_bins = len(bins)
    # op index + num_bins + the bins themselves + error
    required = num_bins + 3
    if required > size:
        raise ValueError(
            "Cannot pack op '{}' with {} bins into a tensor of size {}: "
            "{} slots are required".format(k, num_bins, size, required)
        )

    # Allocate directly with the dtype/device of `error` instead of building a
    # CPU tensor and moving it afterwards.
    out = torch.full((size,), pad_value, dtype=error.dtype, device=error.device)
    out[0] = sorted_op_names.index(k)
    out[1] = num_bins
    if num_bins:
        out[2:2 + num_bins] = torch.tensor(bins, dtype=error.dtype, device=error.device)
    out[2 + num_bins] = error
    return out


def unpack_from_tensor(t):
    """Decode a packed 1-D tensor into ``(op name, bins, error)``.

    Expected layout of ``t``: ``[op index, num_bins, bins..., error, padding...]``,
    where ``op index`` refers to a position in ``sorted_op_names``.

    Returns:
        Tuple of the op name, the list of bins (Python floats) and the error
        (Python float).
    """
    values = t.tolist()
    op_index, num_bins = int(values[0]), int(values[1])
    # The error sits right after the last bin.
    bins_end = 2 + num_bins
    return sorted_op_names[op_index], values[2:bins_end], values[bins_end]
    

def get_error_per_op(policies):
    """Build the packed per-op error tensor for one rank.

    No model is run here: random values stand in for the predicted
    probabilities (10 classes) and for the targets. For each sample, the summed
    absolute difference between its row and the one-hot target is computed, and
    one packed row (see ``pack_as_tensor``) is emitted per op of the sample's
    policy, all sharing that sample's error.

    Args:
        policies: one policy per sample; each policy is an iterable of
            ``(op name, bins)`` pairs.

    Returns:
        2-D tensor with one packed row per op, in sample order.
    """
    num_samples = len(policies)
    fake_probas = torch.rand(num_samples, 10).to(utils.device)
    fake_targets = torch.randint(0, 10, size=(num_samples, )).to(utils.device)

    packed_rows = []
    for proba, target, policy in zip(fake_probas, fake_targets, policies):
        # Subtract the one-hot target in place, then take the L1 norm.
        proba[target] -= 1
        sample_error = proba.abs().sum()
        packed_rows.extend(
            pack_as_tensor(op_name, op_bins, sample_error)
            for op_name, op_bins in policy
        )
    return torch.stack(packed_rows)

In [90]:
# Packed per-op errors as each of the two emulated ranks would compute them.
error_per_op_r1 = get_error_per_op(policies1)
error_per_op_r2 = get_error_per_op(policies2)

In [92]:
# Each row is one packed op: [op index, num_bins, bins..., error, padding (-555)].
error_per_op_r1

tensor([[ 1.8000e+01,  1.0000e+00,  6.2842e-01,  7.1072e+00, -5.5500e+02],
        [ 1.5000e+01,  1.0000e+00,  9.3005e-01,  7.1072e+00, -5.5500e+02],
        [ 1.0000e+01,  2.0000e+00,  4.4102e-01,  3.7404e-01,  4.4467e+00],
        [ 1.0000e+01,  2.0000e+00,  3.4877e-01,  8.0746e-01,  4.4467e+00],
        [ 0.0000e+00,  1.0000e+00,  9.5046e-01,  5.6430e+00, -5.5500e+02],
        [ 2.0000e+00,  1.0000e+00,  5.6785e-01,  5.6430e+00, -5.5500e+02],
        [ 8.0000e+00,  1.0000e+00,  8.0418e-01,  5.0689e+00, -5.5500e+02],
        [ 7.0000e+00,  0.0000e+00,  5.0689e+00, -5.5500e+02, -5.5500e+02],
        [ 6.0000e+00,  1.0000e+00,  1.7532e-03,  4.8191e+00, -5.5500e+02],
        [ 1.8000e+01,  1.0000e+00,  8.6083e-01,  4.8191e+00, -5.5500e+02],
        [ 1.0000e+00,  1.0000e+00,  9.1048e-01,  4.3740e+00, -5.5500e+02],
        [ 9.0000e+00,  1.0000e+00,  3.1135e-01,  4.3740e+00, -5.5500e+02],
        [ 1.4000e+01,  1.0000e+00,  6.4134e-01,  5.5784e+00, -5.5500e+02],
        [ 1.7000e+01,  1.

In [94]:
# Decode every packed row of "rank 1" back into (op name, bins, error).
for t in error_per_op_r1:
    print(unpack_from_tensor(t))

('translate_y', [0.6284188032150269], 7.1072211265563965)
('smooth', [0.9300541281700134], 7.1072211265563965)
('rescale', [0.4410249888896942, 0.3740411400794983], 4.44674825668335)
('rescale', [0.34877246618270874, 0.8074583411216736], 4.44674825668335)
('autocontrast', [0.9504645466804504], 5.643031120300293)
('brightness', [0.5678519606590271], 5.643031120300293)
('invert', [0.8041796088218689], 5.068888187408447)
('identity', [], 5.068888187408447)
('equalize', [0.0017531965859234333], 4.819091796875)
('translate_y', [0.8608303666114807], 4.819091796875)
('blur', [0.910484790802002], 4.374037742614746)
('posterize', [0.3113502860069275], 4.374037742614746)
('shear_y', [0.6413432359695435], 5.578389644622803)
('translate_x', [0.638726532459259], 5.578389644622803)
('translate_x', [0.3843577206134796], 5.851930141448975)
('invert', [0.2567598223686218], 5.851930141448975)


In [93]:
# Each row is one packed op: [op index, num_bins, bins..., error, padding (-555)].
error_per_op_r2

tensor([[ 0.0000e+00,  1.0000e+00,  6.2842e-01,  5.3517e+00, -5.5500e+02],
        [ 2.0000e+00,  1.0000e+00,  9.3005e-01,  5.3517e+00, -5.5500e+02],
        [ 1.4000e+01,  1.0000e+00,  4.4102e-01,  5.9118e+00, -5.5500e+02],
        [ 0.0000e+00,  1.0000e+00,  3.7404e-01,  5.9118e+00, -5.5500e+02],
        [ 1.6000e+01,  1.0000e+00,  3.4877e-01,  3.3680e+00, -5.5500e+02],
        [ 1.6000e+01,  1.0000e+00,  8.0746e-01,  3.3680e+00, -5.5500e+02],
        [ 1.3000e+01,  1.0000e+00,  9.5046e-01,  5.4213e+00, -5.5500e+02],
        [ 1.0000e+01,  2.0000e+00,  5.6785e-01,  8.0418e-01,  5.4213e+00],
        [ 8.0000e+00,  1.0000e+00,  1.7532e-03,  6.0290e+00, -5.5500e+02],
        [ 1.3000e+01,  1.0000e+00,  8.6083e-01,  6.0290e+00, -5.5500e+02],
        [ 1.4000e+01,  1.0000e+00,  9.1048e-01,  6.6215e+00, -5.5500e+02],
        [ 1.6000e+01,  1.0000e+00,  3.1135e-01,  6.6215e+00, -5.5500e+02],
        [ 1.7000e+01,  1.0000e+00,  6.4134e-01,  4.7493e+00, -5.5500e+02],
        [ 2.0000e+00,  1.

In [87]:
import numpy as np

# Build one OP per known op name, with one uniformly random value in [0, 1)
# per entry of cta.rates[k] (an empty list when there are no entries). The
# result is used below to exercise the pack/unpack round trip on every op.
kl = list(OPS.keys())
all_policy_ops = []
for k in kl:
    bins = cta.rates[k]
    rnd = np.random.uniform(0, 1, len(bins))
    all_policy_ops.append(OP(k, rnd.tolist()))
all_policy_ops

[OP(f='autocontrast', bins=[0.5619854072623901]),
 OP(f='blur', bins=[0.7046307087372261]),
 OP(f='brightness', bins=[0.9352483842108498]),
 OP(f='color', bins=[0.39131998433592863]),
 OP(f='contrast', bins=[0.7128734752302066]),
 OP(f='cutout', bins=[0.3314181264088959]),
 OP(f='equalize', bins=[0.7093151749948137]),
 OP(f='invert', bins=[0.5623374776059373]),
 OP(f='identity', bins=[]),
 OP(f='posterize', bins=[0.9360352729278023]),
 OP(f='rescale', bins=[0.49353615762657277, 0.1656012491615455]),
 OP(f='rotate', bins=[0.2959720097201758]),
 OP(f='sharpness', bins=[0.2117442040912716]),
 OP(f='shear_x', bins=[0.02576970813472279]),
 OP(f='shear_y', bins=[0.44570923815749397]),
 OP(f='smooth', bins=[0.49317636039863744]),
 OP(f='solarize', bins=[0.4949358331955145]),
 OP(f='translate_x', bins=[0.4352882706872071]),
 OP(f='translate_y', bins=[0.48944830740684386])]

In [88]:
# Round-trip check: every op must survive pack_as_tensor -> unpack_from_tensor.
for k, bins in all_policy_ops:
    error = torch.rand(1)
    t = pack_as_tensor(k, bins, error)
    new_k, new_bins, new_error = unpack_from_tensor(t)
    assert new_k == k, "{} vs {}".format(new_k, k)
    # zip() stops at the shorter sequence, so compare the lengths explicitly;
    # otherwise bins lost during unpacking would go unnoticed.
    assert len(new_bins) == len(bins), "{} vs {}".format(new_bins, bins)
    # The bins are stored with the dtype of `error`, hence the tolerance.
    assert all(abs(v1 - v2) < 1e-7 for v1, v2 in zip(new_bins, bins)), "{} vs {}".format(new_bins, bins)
    assert new_error == error.item(), "{} vs {}".format(new_error, error.item())

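Gather the `error_per_op` tensors from all ranks. In a real distributed run this is the (commented-out) `dist.all_gather` call below; its `torch.empty_like` receive buffers assume that every rank produces an `error_per_op` of the same shape. Here we emulate the result with the two tensors computed above.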

In [None]:
# tensor_list = [
#     torch.empty_like(error_per_op) 
#     for _ in range(dist.get_world_size())
# ]
# dist.all_gather(tensor_list, error_per_op)

In [95]:
# Single-process stand-in for the all_gather result: one tensor per rank.
tensor_list = [error_per_op_r1, error_per_op_r2]
tensor_list

[tensor([[ 1.8000e+01,  1.0000e+00,  6.2842e-01,  7.1072e+00, -5.5500e+02],
         [ 1.5000e+01,  1.0000e+00,  9.3005e-01,  7.1072e+00, -5.5500e+02],
         [ 1.0000e+01,  2.0000e+00,  4.4102e-01,  3.7404e-01,  4.4467e+00],
         [ 1.0000e+01,  2.0000e+00,  3.4877e-01,  8.0746e-01,  4.4467e+00],
         [ 0.0000e+00,  1.0000e+00,  9.5046e-01,  5.6430e+00, -5.5500e+02],
         [ 2.0000e+00,  1.0000e+00,  5.6785e-01,  5.6430e+00, -5.5500e+02],
         [ 8.0000e+00,  1.0000e+00,  8.0418e-01,  5.0689e+00, -5.5500e+02],
         [ 7.0000e+00,  0.0000e+00,  5.0689e+00, -5.5500e+02, -5.5500e+02],
         [ 6.0000e+00,  1.0000e+00,  1.7532e-03,  4.8191e+00, -5.5500e+02],
         [ 1.8000e+01,  1.0000e+00,  8.6083e-01,  4.8191e+00, -5.5500e+02],
         [ 1.0000e+00,  1.0000e+00,  9.1048e-01,  4.3740e+00, -5.5500e+02],
         [ 9.0000e+00,  1.0000e+00,  3.1135e-01,  4.3740e+00, -5.5500e+02],
         [ 1.4000e+01,  1.0000e+00,  6.4134e-01,  5.5784e+00, -5.5500e+02],
         [ 1

In [96]:
# Concatenate along dim 0 so that every row is one packed op, whatever its rank.
tensor_list = torch.cat(tensor_list, dim=0)

In [97]:
# Rows of "rank 1" come first, followed by the rows of "rank 2".
tensor_list

tensor([[ 1.8000e+01,  1.0000e+00,  6.2842e-01,  7.1072e+00, -5.5500e+02],
        [ 1.5000e+01,  1.0000e+00,  9.3005e-01,  7.1072e+00, -5.5500e+02],
        [ 1.0000e+01,  2.0000e+00,  4.4102e-01,  3.7404e-01,  4.4467e+00],
        [ 1.0000e+01,  2.0000e+00,  3.4877e-01,  8.0746e-01,  4.4467e+00],
        [ 0.0000e+00,  1.0000e+00,  9.5046e-01,  5.6430e+00, -5.5500e+02],
        [ 2.0000e+00,  1.0000e+00,  5.6785e-01,  5.6430e+00, -5.5500e+02],
        [ 8.0000e+00,  1.0000e+00,  8.0418e-01,  5.0689e+00, -5.5500e+02],
        [ 7.0000e+00,  0.0000e+00,  5.0689e+00, -5.5500e+02, -5.5500e+02],
        [ 6.0000e+00,  1.0000e+00,  1.7532e-03,  4.8191e+00, -5.5500e+02],
        [ 1.8000e+01,  1.0000e+00,  8.6083e-01,  4.8191e+00, -5.5500e+02],
        [ 1.0000e+00,  1.0000e+00,  9.1048e-01,  4.3740e+00, -5.5500e+02],
        [ 9.0000e+00,  1.0000e+00,  3.1135e-01,  4.3740e+00, -5.5500e+02],
        [ 1.4000e+01,  1.0000e+00,  6.4134e-01,  5.5784e+00, -5.5500e+02],
        [ 1.7000e+01,  1.

In [98]:
# Decode every gathered row back into (op name, bins, error).
for t in tensor_list:
    print(unpack_from_tensor(t))

('translate_y', [0.6284188032150269], 7.1072211265563965)
('smooth', [0.9300541281700134], 7.1072211265563965)
('rescale', [0.4410249888896942, 0.3740411400794983], 4.44674825668335)
('rescale', [0.34877246618270874, 0.8074583411216736], 4.44674825668335)
('autocontrast', [0.9504645466804504], 5.643031120300293)
('brightness', [0.5678519606590271], 5.643031120300293)
('invert', [0.8041796088218689], 5.068888187408447)
('identity', [], 5.068888187408447)
('equalize', [0.0017531965859234333], 4.819091796875)
('translate_y', [0.8608303666114807], 4.819091796875)
('blur', [0.910484790802002], 4.374037742614746)
('posterize', [0.3113502860069275], 4.374037742614746)
('shear_y', [0.6413432359695435], 5.578389644622803)
('translate_x', [0.638726532459259], 5.578389644622803)
('translate_x', [0.3843577206134796], 5.851930141448975)
('invert', [0.2567598223686218], 5.851930141448975)
('autocontrast', [0.6284188032150269], 5.351713180541992)
('brightness', [0.9300541281700134], 5.351713180541992