In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np
from numpy.fft import irfft
from pathlib import Path
from torch.utils.data import DataLoader
torch.set_float32_matmul_precision('medium')

In [2]:
def get_layers(width: int, latent_dim : int, act_fn : object, encoder = True):
    sizes = []
    size = width
    while size > 30:
        sizes.append(int(size))
        size = size / 2
    sizes.append(latent_dim)

    if encoder==False:
        sizes = sizes[::-1]
    layers = []
    for layer_idx in range(len(sizes) - 1):
        layers.append(nn.Linear(sizes[layer_idx], sizes[layer_idx+1]))
        layers.append(act_fn())
    layers.pop()

    return layers

class Encoder(nn.Module):
    def __init__(self, width: int, latent_dim: int, act_fn: object = nn.ReLU):
        super().__init__()
        self.net = nn.Sequential(*get_layers(width, latent_dim, act_fn))

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    def __init__(self, width: int, latent_dim: int, act_fn: object = nn.ReLU):
        super().__init__()
        self.net = nn.Sequential(*get_layers(width, latent_dim, act_fn, encoder=False))
                                 
    def forward(self, x):
        return self.net(x)

In [3]:
class AutoEncoder(pl.LightningModule):
    def __init__(
        self,
        width: int,
        latent_dim: int,
        lr: float,
        encoder_class: object = Encoder,
        decoder_class: object = Decoder,
    ):
        super(AutoEncoder, self).__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.width = width
        self.latent_dim = latent_dim
        self.encoder = encoder_class(width, latent_dim)
        self.decoder = decoder_class(width, latent_dim)
        self.val_loss = []

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, x):
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def training_step(self, batch):
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch):
        loss = self._get_reconstruction_loss(batch)
        self.val_loss.append(loss)
        self.log("val_loss", loss, sync_dist=True)

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.val_loss).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.val_loss.clear()

    def _test_plot(self, noisy:torch.Tensor, denoised:torch.Tensor, idx:int,
                   loss:float):
        noisy = noisy.cpu().numpy()
        denoised = denoised.cpu().numpy()
        plt.figure(figsize=(8, 4))
        plt.subplot(1,2,1)
        plt.plot(noisy[idx,:], label='Noisy')
        plt.plot(denoised[idx,:], label='Denoised')
        plt.legend(loc='upper right')
        plt.xlim([0, self.width-1])
        plt.ylabel('Normalized amplitude')
        plt.xlabel('index')
        plt.subplot(1,2,2)
        plt.plot(noisy[idx,:] - denoised[idx,:], label='Residual')
        plt.legend()
        plt.xlim([0, self.width-1])
        plt.xlabel('index')
        plt.suptitle(f'Loss: {loss:.7f}')
        plt.savefig(f'reports/figures/{self.width}_{self.latent_dim}_{idx}.png', dpi=200)
        plt.close()

    def test_step(self, batch):
        x_hat = self.forward(batch)
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)
        rand = np.random.random(10)*batch.size(dim=0)  # indexes for plotting
        for rand_no in rand:
            self._test_plot(batch, x_hat, int(rand_no), loss)

In [4]:
class GetData():
    def __init__(
        self,
        width: int = 128
    ):
        super().__init__()
        self.width = width
    
    def _gen_blip(self, fs:int, flow:int, fhigh:int, dt_shift:float):
        freqs = np.arange(1 + fs//2)
        spec = np.zeros(len(freqs))
        logf = np.log(freqs[flow:fhigh])
        spec1 = (logf-logf[0])*(logf[-1]-logf)
        spec[flow:fhigh] = spec1/np.max(spec1)
        spec_shifted = np.exp(-1j*freqs*2*np.pi*dt_shift)*spec
        blip = np.roll(irfft(spec_shifted), fs//2)
        return blip / np.max(blip)

    def fake_blips(self, fs: int = 512):
        dt_shifts = np.random.normal(0, 5, 50) / 1000  # in ms
        f_lows = np.linspace(15, 35, 21, dtype=int)
        f_highs = np.linspace(190, 230, 41, dtype=int)

        blips = []
        for dt_shift in dt_shifts:
            for f_low in f_lows:
                for f_high in f_highs:
                    blip = self._gen_blip(fs, f_low, f_high, dt_shift)
                    blips.append(blip)
        blips = np.array(blips)
        blips = blips[:, fs//2-self.width//2:fs//2+self.width//2]
        blips = np.array(blips).astype('float32')
        return blips

    def shift_blips(self, dset : object = np.array): 
        shift_idxs = range(-5, 6)
        shifted_blips = []
        for idx in shift_idxs:
            shifted_blips.append(np.roll(dset, idx, axis=1))
        shifted_blips = np.array(shifted_blips)
        shifted_blips = np.vstack(shifted_blips)
        return shifted_blips

    def real_blips(self, ddir : object = Path('data/external')):
        files = ddir.glob('*.npy')
        blips = []
        for blip in files:
            blip_data = np.load(blip)[:,1]
            blips.append(blip_data)
        blips = np.array(blips)
        length = len(blips.T)
        blips = blips[:, length//2-self.width//2:length//2+self.width//2]
        blips = blips.astype('float32')
        return blips

In [5]:
class GlitchDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data: str = 'real',
        batch_size: int = 100,
        width: int = 128
    ):
        super().__init__()
        self.data = data
        self.batch_size = batch_size
        self.width = width
        self.prepare_data_per_node = False

    def prepare_data(self):
        gd = GetData(self.width)
        if self.data == 'real':
            self.glitch = F.normalize(torch.from_numpy(gd.real_blips()))
        elif self.data == 'fake':
            self.glitch = F.normalize(torch.from_numpy(gd.fake_blips()))
        else:
            raise SystemExit("Works only with 'real' and 'fake' data")

    def setup(self, stage=None):
        self.glitch_train, self.glitch_val, self.glitch_test = torch.utils.data.random_split(self.glitch, [0.89, 0.1, 0.01], generator=torch.manual_seed(0))

    def train_dataloader(self):
        if self.data == 'real':
            # generate more glitches for training by shifting them
            gd = GetData(self.width)
            return DataLoader(gd.shift_blips(self.glitch_train), batch_size=self.batch_size, num_workers=4)
        else:
            return DataLoader(self.glitch_train, batch_size=self.batch_size, num_workers=4)
    def val_dataloader(self):
        return DataLoader(self.glitch_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.glitch_test, batch_size=len(self.glitch_test), num_workers=4)

In [6]:
data = 'real'
width = 128
latent_dim = 5
lr = 1e-3
epochs = 2
ae = AutoEncoder(width, latent_dim, lr)
dm = GlitchDataModule(data=data, width=width)

trainer = pl.Trainer(max_epochs=epochs, accelerator="auto")
trainer.fit(model=ae, datamodule=dm)
trainer.save_checkpoint("example.ckpt")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

In [None]:
model = AutoEncoder.load_from_checkpoint("example.ckpt")
dm = GlitchDataModule(data='real', width=model.width)
trainer = pl.Trainer()
trainer.test(datamodule=dm, model=model, ckpt_path="example.ckpt")