In [3]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import Conv2d, Linear, ConvTranspose2d
from complextensor import ComplexTensor

In [4]:
class complexLayer(nn.Module):
    '''
    Turns a pytorch layer into a complex layer. Works for Linear, Conv and ConvTranspose.

    Two bias-free copies of ``Layer`` are built from the same keyword
    arguments: ``f_ma`` is applied to ``x.magn`` and ``f_ph`` to
    ``x.phase``. The two outputs (plus an optional bias) are concatenated
    along dim -2 and re-tagged as a ``ComplexTensor``.

    Calls go through ``nn.Module.__call__``, so registered hooks run.

    Args:
        Layer: layer class (e.g. ``nn.Linear``, ``nn.Conv2d``,
            ``nn.ConvTranspose2d``), constructed as ``Layer(**kwargs)``.
        kwargs: dict of keyword arguments for ``Layer``. Must contain
            ``out_channels`` or ``out_features``. ``kwargs['bias']``
            (default False) controls whether a separate bias is added in
            ``forward``. The caller's dict is not modified.

    Raises:
        ValueError: if ``kwargs`` has neither ``out_channels`` nor
            ``out_features``.
    '''

    # Keyword names that can carry the output dimension, in lookup order.
    _OUT_DIM_KEYS = ('out_channels', 'out_features')

    def __init__(self, Layer, kwargs):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bias = kwargs.get('bias', False)
        out_keys = [key for key in self._OUT_DIM_KEYS if key in kwargs]
        if not out_keys:
            raise ValueError(
                "complexLayer needs 'out_channels' or 'out_features' in kwargs"
            )
        self.outType = out_keys[0]
        self.out_dim = kwargs[self.outType]
        # Work on a copy so a dict the caller reuses keeps its own 'bias'
        # setting. The layers' own bias is turned off so each layer only does
        # the matrix multiplication; if you leave the bias on, the complex
        # arithmetic does not work out correctly. When requested, the bias is
        # added separately in forward().
        layer_kwargs = {**kwargs, 'bias': False}
        self.f_ma = Layer(**layer_kwargs)
        self.f_ph = Layer(**layer_kwargs)
        self.b = None
        if self.bias:
            b_r = np.random.randn(self.out_dim, 1).astype('float32')
            b_i = np.random.randn(self.out_dim, 1).astype('float32')
            # NOTE(review): stored as a plain attribute, not an nn.Parameter,
            # so it is presumably neither trained nor moved by .to(device) —
            # confirm this is intended.
            self.b = ComplexTensor(b_r + 1j * b_i)

    def forward(self, x):
        '''
        Apply ``f_ma`` to ``x.magn`` and ``f_ph`` to ``x.phase``.

        Args:
            x: object exposing ``.magn`` and ``.phase`` tensors (presumably a
                ComplexTensor — TODO confirm).

        Returns:
            ComplexTensor: the two layer outputs, each plus its bias part
            when bias is enabled, concatenated along dim -2.
        '''
        magnitude = self.f_ma(x.magn)
        phas = self.f_ph(x.phase)
        if self.bias:
            if self.outType == 'out_channels':
                # Broadcast per channel; assumes a 4-D (N, C, H, W) conv output.
                bias_shape = (1, self.out_dim, 1, 1)
            else:
                # Broadcast over the trailing feature dimension.
                bias_shape = (self.out_dim,)
            # self.b has exactly out_dim entries per part by construction.
            # The bias must be added to the layer outputs themselves.
            magnitude = magnitude + self.b.real.reshape(bias_shape)
            phas = phas + self.b.imag.reshape(bias_shape)
        result = torch.cat([magnitude, phas], dim=-2)
        result.__class__ = ComplexTensor
        return result

In [19]:
# Synthetic input for exercising complexLayer: `bz` samples, each holding a
# magnitude block (index 0 of dim 1) and a phase block (index 1), presumably
# 3 channels of 100x100 values.
SEED = 0
torch.manual_seed(SEED)  # reproducible x
np.random.seed(SEED)  # reproducible numpy draws (e.g. complexLayer's bias)
bz = 16
bias = True # vary this for testing purposes
x = torch.randn((bz, 2, 3, 100, 100))
x_np = x.detach().numpy()
# Plain indexing removes only dim 1; np.squeeze would also drop the batch
# dimension when bz == 1.
magn = x_np[:, 0]
phase = x_np[:, 1]
# Magnitude stacked above phase along the row axis -> (bz, 3, 200, 100).
z = np.concatenate([magn, phase], axis=-2)

In [23]:
# NOTE(review): displaying this object raised
# "TypeError: __repr__ returned non-string (type numpy.ndarray)" in the saved run.
# Here z is a real float32 ndarray with magnitude stacked above phase along
# axis -2, whereas complexLayer builds its bias from a complex ndarray
# (b_r + 1j*b_i). TODO: confirm which input type/layout ComplexTensor expects
# (e.g. a torch tensor or a complex array) before relying on this cell.
ComplexTensor(z)

TypeError: __repr__ returned non-string (type numpy.ndarray)