Skip to content

Writing memory format aware operators

Vitaly Fedyunin edited this page Mar 23, 2020 · 6 revisions

Memory format aware operators are the operators which satisfy two requirements:

  • they generate output in same memory format as inputs
  • they use the most efficient kernels for each different memory formats

Let say we want to add/modify operator to support torch.channels_last memory format.

in_tensor = x.contiguous(memory_format=torch.channels_last)
out_tensor = torch.operator(in_tensor) 
print(out_tensor.is_contiguous(memory_format=torch.channels_last)) # True

To do so, we need to modify the operator's CPP code. An old version of operator might look similar to this:

auto output_tensor = at::empty_like(input_tensor);
// .... standard kernel for contiguous or strided tensors
return output_tensor;

The preferred way of writing memory format aware operators is to use the switch operator. This approach allows us to expand memory formats support in the future.

// ...
auto memory_format = input_tensor.suggest_memory_format();
auto output_tensor = at::empty(output_shape, memory_format);

switch (memory_format) {
  case MemoryFormat::ChannelsLast: {
    auto input_cl_contiguous = input_tensor.contiguous(
        MemoryFormat::ChannelsLast); // if kernel requires memory dense
                                     // tensor
    // .... kernel code
    break;
  }
  case MemoryFormat::Contiguous: {
    // .... standard kernel for contiguous or strided tensors
    break;
  }
  default:
    TORCH_CHECK(
        false,
        "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
// ...

Important to learn that suggest_memory_format is not similar to input_tensor.is_contiguous(...), see function comments.

More memory format handling required when you are writing _out operator implementation.

in_tensor = x.contiguous(memory_format=torch.channels_last)
out_tensor = o.contiguous(memory_format=torch.contiguous_format)
torch.operator(in_tensor, out=out_tensor) 
print(out_tensor.is_contiguous(memory_format=torch.contiguous_format)) # True

Keeping the memory format of the output is essential. However, some performant algorithms require matching formats of inputs and outputs. In this case, it is possible to do a copy_ trick.

Tensor self_or_new_memory_format(Tensor& self, MemoryFormat memory_format) {
    if (self.is_contiguous(memory_format)) {
        return self;
    }
    return at::empty_like(self, self.options(), memory_format);
}
// ...
auto memory_format = input_tensor.suggest_memory_format();

assert_no_internal_overlap(output);
if (output_shape != output.sizes()) {
    output.resize_(output_shape, memory_format);
}

auto temporary_output_tensor = self_or_new_memory_format(output, memory_format); 

switch (memory_format) {
  case MemoryFormat::ChannelsLast: {
    auto input_cl_contiguous = input_tensor.contiguous(
        MemoryFormat::ChannelsLast); // if kernel requires memory dense
                                     // tensor
    // .... kernel code
    break;
  }
  case MemoryFormat::Contiguous: {
    // .... standard kernel
    break;
  }
  default:
    TORCH_CHECK(
        false,
        "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}

if (!output.is_same(temporary_output_tensor)) {
    output.copy_(temporary_output_tensor);
}
// ...

In some cases, there is no performant algorithm for contiguous or channels last inputs, so the same trick with temporary tensors and copy_ can be applied.

// ...
auto memory_format = input_tensor.suggest_memory_format();

assert_no_internal_overlap(output);
if (output_shape != output.sizes()) {
    output.resize_(output_shape, memory_format);
}

auto temporary_output_tensor = self_or_new_memory_format(output, MemoryFormat::ChannelsLast);
auto input_cl_contiguous = input_tensor.contiguous(MemoryFormat::ChannelsLast); 
// .... channels last kernel code
 
if (!output.is_same(temporary_output_tensor)) {
    output.copy_(temporary_output_tensor);
}
// ...

Or you can do hard exit with unsupported memory format message (this is least preferred way, and we consider such operators incomplete).

// ...
switch (memory_format) {
  case MemoryFormat::ChannelsLast: {
    auto input_cl_contiguous = input_tensor.contiguous(
        MemoryFormat::ChannelsLast); // if kernel requires memory dense
                                     // tensor
    // .... kernel code
    break;
  }
  case MemoryFormat::Contiguous:
  default:
    TORCH_CHECK(
        false,
        "Unsupported memory format. Supports only ChannelsLast");
}
// ...

Please do not forget to cover all scenarios with unit tests. We had seen countless cases when simple test saved hours of debugging.

Clone this wiki locally