# Test TS2Kit

In [None]:
import torch 
import numpy as np
from ts2kit import FTSHT, ITSHT, gridDH, clearTS2KitCache

In [None]:
# Uncomment to clear cache (assuming you've set the cache path in ts2kit.py)
#clearTS2KitCache()

In [None]:
class tshtTest(torch.nn.Module):
    """Round-trip test module: inverse SHT followed by forward SHT.

    Applies ``ITSHT(B)`` to the input and then ``FTSHT(B)`` to that result.
    For bandlimited coefficient input, the output is expected to reproduce
    the input up to numerical error (this is what the notebook measures).

    Args:
        B: bandlimit passed to both transforms.
    """

    def __init__(self, B):
        super().__init__()
        # Construction order (forward first, then inverse) is kept as-is.
        self.fsht = FTSHT(B)
        self.isht = ITSHT(B)

    def forward(self, x):
        # Inverse transform first, then forward transform back.
        intermediate = self.isht(x)
        return self.fsht(intermediate)
    

In [None]:
## Parameters - change these to try different configurations

## Bandlimit
B = 64

## Batch size: number of independent coefficient sets transformed at once
## (the size of the single leading batch dimension, not a count of dimensions)
b = 4096

## Torch device (must be CUDA: the timing cells use torch.cuda.Event)
device = torch.device('cuda')

In [None]:
## Initialize the module at double (float64) and single (float32) precision

# On first run at a given bandlimit, you should see several statements printed in the console confirming
# that various tensors have been pre-computed successfully. These will be saved to the cache directory and
# automatically loaded during subsequent initializations.

# Cast explicitly so each module's precision matches its name instead of relying on whatever
# dtype ts2kit uses for its precomputed buffers. Casting before .to(device) converts on the
# host, avoiding an extra on-device copy.
test_double = tshtTest(B).double().to(device)
test_float = tshtTest(B).float().to(device)

In [None]:
## Generate random SH coefficients (real and imaginary parts uniform in [-1, 1))

# Layout is (batch, m + (B-1), l) with m in [-(B-1), B-1] and l in [0, B-1].
# Entries with l < |m| have no corresponding spherical harmonic and are zeroed.
# A single boolean mask replaces the per-(m, l) Python loop, which issued ~(2B-1)*B
# separate indexed writes on the device tensor.
_abs_m = torch.arange(-(B - 1), B).abs().unsqueeze(1)  # |m| per row, shape (2B-1, 1)
_l = torch.arange(B).unsqueeze(0)                       # l per column, shape (1, B)
_invalid = _l < _abs_m                                  # shape (2B-1, B)

# Draw directly in float64 so the double-precision test gets full-precision random values.
Psi_double = torch.view_as_complex(
    2 * (torch.rand(b, 2 * B - 1, B, 2, dtype=torch.float64) - 0.5)
)
Psi_double[:, _invalid] = 0.0
Psi_double = Psi_double.to(device)

# .cfloat() changes dtype and therefore already returns a new tensor; no clone() needed.
Psi_float = Psi_double.cfloat()

In [None]:
## Run time and reconstruction error at double precision
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    # Warm-up: the first invocation can include one-time setup costs that would
    # otherwise be counted in the measured run time.
    test_double(Psi_double)
    torch.cuda.synchronize()

    start.record()
    PsiR_double = test_double(Psi_double)
    end.record()

# Events are recorded asynchronously; wait for them before reading elapsed time.
torch.cuda.synchronize()

# Relative L1 reconstruction error of the inverse -> forward round trip.
rel_err_double = torch.sum(torch.abs(Psi_double - PsiR_double)) / torch.sum(torch.abs(Psi_double))

print('Run time: {:06.4f} ms'.format(start.elapsed_time(end)), flush=True)
print('error = {}'.format(rel_err_double.item()))

In [None]:
## Run time and reconstruction error at single (float32) precision
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    # Warm-up: the first invocation at this precision can include one-time setup
    # costs that would otherwise be counted in the measured run time.
    test_float(Psi_float)
    torch.cuda.synchronize()

    start.record()
    PsiR_float = test_float(Psi_float)
    end.record()

# Events are recorded asynchronously; wait for them before reading elapsed time.
torch.cuda.synchronize()

# Relative L1 reconstruction error of the inverse -> forward round trip.
rel_err_float = torch.sum(torch.abs(Psi_float - PsiR_float)) / torch.sum(torch.abs(Psi_float))

print('Run time: {:06.4f} ms'.format(start.elapsed_time(end)), flush=True)
print('error = {}'.format(rel_err_float.item()))