# Generative Adversarial Networks

Transposed Convolutions

In [1]:
import torch as t
import utils
import past
import einops
import typing
from typing import Union

In [2]:
def conv_transpose1d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    '''Transposed 1D convolution with stride 1, no padding and no bias, mirroring torch's conv_transpose1d defaults.

    The result is obtained by running an ordinary convolution (past.conv1d_minimal)
    over a padded input with a reversed, channel-swapped kernel.

    x: shape (batch, in_channels, width)
    weights: shape (in_channels, out_channels, kernel_width)

    Returns: shape (batch, out_channels, output_width)
    '''
    # Reverse the kernel along its width, then swap the two channel axes so the
    # kernel has the (out_channels, in_channels, kernel_width) layout named in the pattern.
    reversed_kernel = weights.flip(-1)
    conv_kernel = einops.rearrange(reversed_kernel, 'in_channels out_channels kernel_width -> out_channels in_channels kernel_width')

    # Pad by kernel_width - 1 on each side.
    # NOTE(review): assumes past.pad1d takes (x, left, right, pad_value) - confirm against `past`.
    border = weights.shape[2] - 1
    padded_input = past.pad1d(x, border, border, 0)

    return past.conv1d_minimal(padded_input, conv_kernel)


# Pass the implementation to the test helper in `utils` (what it checks is defined there, not in this file).
utils.test_conv_transpose1d_minimal(conv_transpose1d_minimal)

All tests in `test_conv1d_minimal` passed!


In [6]:
def fractional_stride_1d(x, stride: int = 1):
    '''Returns a version of x suitable for transposed convolutions, i.e. "spaced out" with zeros between its values.
    This spacing only happens along the last dimension.

    x: shape (batch, in_channels, width)
    stride: number of output positions per input position; stride - 1 zeros are
        inserted between neighbouring values. Must be >= 1.

    Returns: shape (batch, in_channels, (width - 1) * stride + 1), with the same
        dtype and device as x. An input with width 0 yields an output with width 0.

    Raises: ValueError if stride < 1.

    Example: 
        x = [[[1, 2, 3], [4, 5, 6]]]
        stride = 2
        output = [[[1, 0, 2, 0, 3], [4, 0, 5, 0, 6]]]
    '''
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    batch, in_channels, width = x.shape

    # Empty input: (width - 1) * stride + 1 would be negative for stride > 1,
    # so return an empty tensor directly.
    if width == 0:
        return x.new_zeros((batch, in_channels, 0))

    new_width = (width - 1) * stride + 1
    # new_zeros inherits both dtype and device from x, so the slice assignment
    # below never mixes devices.
    new_x = x.new_zeros((batch, in_channels, new_width))

    # Every stride-th slot receives an original value; the rest stay zero.
    new_x[:, :, ::stride] = x

    return new_x

# Pass the implementation to the test helper in `utils` (what it checks is defined there, not in this file).
utils.test_fractional_stride_1d(fractional_stride_1d)

All tests in `test_fractional_stride_1d` passed!


In [7]:
def conv_transpose1d(x, weights, stride: int = 1, padding: int = 0) -> t.Tensor:
    '''Transposed 1D convolution with stride and padding, like torch's conv_transpose1d with bias=False.

    Implemented as: space the input out with zeros (fractional_stride_1d), pad it,
    then run an ordinary convolution with a reversed, channel-swapped kernel.

    x: shape (batch, in_channels, width)
    weights: shape (in_channels, out_channels, kernel_width)
        (this is the layout the rearrange pattern below expects)
    stride: spacing applied to x before convolving
    padding: amount subtracted from the kernel_width - 1 border on each side

    Returns: shape (batch, out_channels, output_width)
    '''
    spaced_x = fractional_stride_1d(x, stride)

    # NOTE(review): the border is negative when padding > kernel_width - 1;
    # whether past.pad1d accepts that is not visible here - confirm.
    border = weights.shape[2] - 1 - padding
    padded_input = past.pad1d(spaced_x, border, border, 0)

    # Reverse the kernel along its width, then swap the channel axes.
    reversed_kernel = weights.flip(-1)
    conv_kernel = einops.rearrange(reversed_kernel, 'in_channels out_channels kernel_width -> out_channels in_channels kernel_width')

    # The padded input is cast to float32 before convolving.
    # NOTE(review): this also downcasts float64 inputs - confirm that is intended.
    return past.conv1d_minimal(padded_input.float(), conv_kernel)
    
# Pass the implementation to the test helper in `utils` (what it checks is defined there, not in this file).
utils.test_conv_transpose1d(conv_transpose1d)

All tests in `test_conv_transpose1d` passed!


In [8]:
# A single int, or an explicit 2-tuple of ints.
IntOrPair = Union[int, tuple[int, int]]
# Always a 2-tuple of ints.
Pair = tuple[int, int]

def force_pair(v: IntOrPair) -> Pair:
    '''Normalise v into a 2-tuple.

    v: either a 2-tuple (each element is passed through int()) or an int
        (which is duplicated into both positions).

    Returns: the resulting pair.

    Raises: ValueError if v is a tuple whose length is not 2, or is neither a tuple nor an int.
    '''
    if isinstance(v, tuple):
        if len(v) == 2:
            first, second = v
            return (int(first), int(second))
        raise ValueError(v)
    if isinstance(v, int):
        return (v, v)
    raise ValueError(v)

def fractional_stride_2d(x, stride_h: int, stride_w: int):
    '''
    Same as fractional_stride_1d, except we apply it along the last 2 dims of x (height and width).

    x: shape (batch, in_channels, height, width)
    stride_h: spacing along the height axis; must be >= 1
    stride_w: spacing along the width axis; must be >= 1

    Returns: shape (batch, in_channels, (height - 1) * stride_h + 1, (width - 1) * stride_w + 1),
        with the same dtype and device as x. A spatial dim of size 0 stays 0.

    Raises: ValueError if either stride is < 1.
    '''
    if stride_h < 1 or stride_w < 1:
        raise ValueError(f"strides must be positive integers, got ({stride_h}, {stride_w})")

    batch, in_channels, height, width = x.shape

    # Guard the empty case: (n - 1) * stride + 1 is negative for n == 0 and stride > 1.
    new_height = (height - 1) * stride_h + 1 if height > 0 else 0
    new_width = (width - 1) * stride_w + 1 if width > 0 else 0

    # new_zeros keeps x's dtype and device (plain t.zeros would silently produce
    # float32 on the default device, truncating float64 and converting ints).
    new_x = x.new_zeros((batch, in_channels, new_height, new_width))

    if height > 0 and width > 0:
        # Original values land on the (stride_h, stride_w) grid; everything else stays zero.
        new_x[:, :, ::stride_h, ::stride_w] = x

    return new_x

def conv_transpose2d(x, weights, stride: IntOrPair = 1, padding: IntOrPair = 0) -> t.Tensor:
    '''Like torch's conv_transpose2d using bias=False

    Implemented as: space the input out with zeros (fractional_stride_2d), pad it,
    then run an ordinary convolution with a kernel that is reversed along both
    spatial axes and has its channel axes swapped.

    x: shape (batch, in_channels, height, width)
    weights: shape (in_channels, out_channels, kernel_height, kernel_width)
        (this is the layout the rearrange pattern below expects)
    stride: int, or (stride_height, stride_width) pair
    padding: int, or (padding_height, padding_width) pair

    Returns: shape (batch, out_channels, output_height, output_width)

    Raises: ValueError (from force_pair) if stride or padding is neither an int nor a 2-tuple.
    '''
    # Unpack into fresh names rather than re-annotating the parameters with a different type.
    stride_h, stride_w = force_pair(stride)
    padding_h, padding_w = force_pair(padding)

    spaced_x = fractional_stride_2d(x, stride_h, stride_w)

    # NOTE(review): these are negative when padding exceeds kernel size - 1;
    # whether past.pad2d accepts that is not visible here - confirm.
    pad_size_h = weights.shape[-2] - 1 - padding_h
    pad_size_w = weights.shape[-1] - 1 - padding_w
    # Width padding is passed before height padding, as in the original call.
    # NOTE(review): presumably (left, right, top, bottom, pad_value) - confirm against `past`.
    pad_x = past.pad2d(spaced_x, pad_size_w, pad_size_w, pad_size_h, pad_size_h, 0)

    # Reverse the kernel along width and height, then swap the channel axes.
    kernel_mod = weights.flip(-1).flip(-2)
    kernel_mod = einops.rearrange(kernel_mod, 'in_channels out_channels kernel_height kernel_width -> out_channels in_channels kernel_height kernel_width')

    # The padded input is cast to float32 before convolving.
    # NOTE(review): this also downcasts float64 inputs - confirm that is intended.
    pad_x = pad_x.float()
    return past.conv2d_minimal(pad_x, kernel_mod)

# Pass the implementation to the test helper in `utils` (what it checks is defined there, not in this file).
utils.test_conv_transpose2d(conv_transpose2d)

All tests in `test_conv_transpose2d` passed!
