In [1]:
# Development convenience: with autoreload mode 2, imported modules are reloaded
# automatically before each cell executes, so edits to local .py files are
# picked up without restarting the kernel.
%load_ext autoreload
#%load_ext line_profiler
%autoreload 2

In [2]:
import pylab as plt
import numpy as np
import torch
import torch.nn as nn
import typing as tp
from torch.nn import functional as F

import swyft
from swyft.lightning.components_v1 import (
    MeanStd, SimpleDataset, subsample_posterior,
    get_1d_rect_bounds, append_randomized, valmap, SwyftModel, persist_to_file, 
    equalize_tensors, SwyftTrainer, SampleStore, dictstoremap, DictDataset, SwyftModule, SwyftDataModule, RatioSamples,
    RatioSampleStore, file_cache
)

import lensing_model
from pytorch_lightning import loggers as pl_loggers

from swyft.lightning.components_v1 import RatioEstimatorGaussian1d as RatioEstimatorGaussian1d_v1
from swyft.lightning.components import RatioEstimatorGaussian1d, RatioEstimatorMLP1d

In [3]:
# Switch between the two implementations wired up in LensNetwork / SourceNetwork:
#   True  -> v1 path (ParameterTransform + MarginalClassifier, RatioEstimatorGaussian1d_v1)
#   False -> RatioEstimatorMLP1d / RatioEstimatorGaussian1d from swyft.lightning.components
WORKS = False

## Problem-specific analysis components

In [4]:
# Number of neighbours requested from lensing_model.get_kNN_idx (passed as k).
KNN = 3
# Standard deviation of the Gaussian noise that Model.fast adds to the image.
SIGMA = 0.02
# Side length (in pixels) of the source-plane and image-plane grids.
NPIX_SRC = NPIX_IMG = 100

class Model(SwyftModel):
    """Simulator and priors for the lensing example.

    ``slow`` renders the lensed image ``mu`` (plus coordinate grids and kNN
    indices) on CUDA device 0, ``fast`` adds Gaussian pixel noise to ``mu``,
    and the ``prior*`` methods draw the source image ``z_src`` and the six
    lens parameters ``z_lens``.
    """

    def slow(self, pars):
        """Run the expensive, GPU-based part of the simulation.

        Args:
            pars: mapping with ``'z_lens'`` (six values: x, y, phi, q, r_ein,
                slope) and ``'z_src'`` (source image tensor).

        Returns:
            SampleStore with CPU tensors ``mu``, ``kNN_idx``, ``X``, ``Y``,
            ``Xsrc`` and ``Ysrc``.
        """
        torch.cuda.set_device(0)
        # The default tensor type is process-global state. Restore it in a
        # ``finally`` so that a failure inside lensing_model does not leave the
        # rest of the notebook creating CUDA tensors by default.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        try:
            x, y, phi, q, r_ein, _ = pars['z_lens']
            # The sampled slope is deliberately discarded and fixed to 2.0.
            slope = 2.0
            src_image = pars['z_src'].cuda()
            img, coords = lensing_model.image_generator(x, y, phi, q, r_ein, slope, src_image)
            X, Y, Xsrc, Ysrc = coords
            kNN_idx = lensing_model.get_kNN_idx(X/5, -Y/5, Xsrc, Ysrc, k = KNN)  # TODO: Need to sort out strange 1/5 and -1/5 factors
        finally:
            torch.set_default_tensor_type(torch.FloatTensor)
        return SampleStore(mu = img.cpu(), kNN_idx = kNN_idx.cpu(), X = X.cpu(), Y = Y.cpu(), Xsrc = Xsrc.cpu(), Ysrc = Ysrc.cpu())

    def fast(self, d):
        """Add Gaussian noise with standard deviation ``SIGMA`` to ``d['mu']``.

        Returns:
            SampleStore with the noisy image under ``img``.
        """
        img = d['mu'] + torch.randn_like(d['mu'])*SIGMA
        return SampleStore(img=img)

    def prior(self, N, bounds = None):
        """Draw ``N`` joint prior samples (``z_src`` and ``z_lens``).

        Args:
            N: number of samples.
            bounds: optional mapping that may contain ``'z_src'`` and/or
                ``'z_lens'`` entries exposing ``.low`` and ``.high``.
        """
        src_samples = self.prior_src(N, bounds = bounds)
        lens_samples = self.prior_lens(N, bounds = bounds)
        return SampleStore(**src_samples, **lens_samples)

    # Draw from source prior
    def prior_src(self, N, bounds = None):
        """Draw ``N`` source images.

        Without bounds for ``z_src`` the images come straight from
        ``lensing_model.RandomSource``. With bounds, a random field is built
        from differences of random sources, normalised to [0, 1] and mapped
        into the interval [low, high].
        """
        if bounds is None or 'z_src' not in bounds:
            R = lensing_model.RandomSource()
            z_src = torch.stack([R().cpu() for _ in range(N)])
        else:
            n = 3  # number of source differences summed per sample
            l, h = bounds['z_src'].low, bounds['z_src'].high
            R = lensing_model.RandomSource()
            z_src = []
            for _ in range(N):
                rnd = sum([R().cpu()-R().cpu() for _ in range(n)])
                # Normalise to [0, 1].
                rnd -= rnd.min()
                rnd /= rnd.max()
                # Map [0, 1] onto [low, high]; scaling by ``h`` alone would
                # overshoot the upper bound whenever ``l`` is non-zero.
                z_src.append(l + rnd*(h - l))
            z_src = torch.stack(z_src)
        return SampleStore(z_src=z_src)

    def prior_lens(self, N, bounds = None):
        """Draw ``N`` lens parameter vectors uniformly between ``low`` and ``high``.

        Uses ``bounds['z_lens']`` when available and the built-in default box
        otherwise. Bounds that only constrain ``z_src`` are accepted, matching
        ``prior_src``.
        """
        if bounds is not None and 'z_lens' in bounds:
            low = bounds['z_lens'].low
            high = bounds['z_lens'].high
        else:
            low =  np.array([-0.2, -0.2, 0, 0.2, 1.0, 1.5])
            high = np.array([0.2, 0.2, 1.5, 0.9, 2.0, 2.5])
        draw = np.array([np.random.uniform(low=low, high=high) for _ in range(N)])
        return SampleStore(z_lens = torch.tensor(draw).float())

In [5]:
# Simulator instance shared by the cells below.
m = Model()

In [6]:
# Target samples are read from disk. The two commented lines below are the
# one-off generation step, kept for provenance; uncomment them to regenerate.
#s_targets = m.sample(10)
#torch.save(s_targets, "test_targets.pt")
# NOTE(review): torch.load unpickles the file, so only load a file you trust.
s_targets = torch.load("test_targets.pt")

In [7]:
# Works

class LensNetwork(SwyftModule):
    """Ratio estimator for the six lens parameters ``z_lens`` given the image ``img``.

    The image is standardised online, compressed by a small CNN into a
    16-dimensional summary and then combined with ``z_lens`` by either the v1
    classifier stack (``WORKS`` true) or ``RatioEstimatorMLP1d`` (``WORKS`` false).
    """

    def __init__(self):
        super().__init__()
        n_features = 16     # width of the CNN summary vector
        n_lens_params = 6   # number of entries in z_lens

        self.online_z_score = swyft.networks.OnlineDictStandardizingLayer(
            {'img': (NPIX_IMG, NPIX_IMG)}
        )

        # Three conv/pool stages followed by a dense head; the lazy layers
        # infer their input width on the first forward pass.
        conv_stages = [
            nn.Conv2d(1, 10, 3),
            nn.MaxPool2d(2),
            nn.Conv2d(10, 20, 3),
            nn.MaxPool2d(2),
            nn.Conv2d(20, 40, 3),
            nn.MaxPool2d(2),
        ]
        dense_head = [
            nn.Flatten(),
            nn.LazyLinear(128),
            nn.ReLU(),
            nn.LazyLinear(256),
            nn.ReLU(),
            nn.Linear(256, n_features),
        ]
        self.CNN = nn.Sequential(*conv_stages, *dense_head)

        if not WORKS:
            self.c = RatioEstimatorMLP1d(n_features, n_lens_params)
        else:
            # One 1-dim marginal per lens parameter.
            marginals = [(i,) for i in range(n_lens_params)]
            self.ptrans = swyft.networks.ParameterTransform(
                n_lens_params, marginals, online_z_score=False
            )
            n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
            self.classifier = swyft.networks.MarginalClassifier(
                n_marginals,
                n_features + n_block_parameters,
                hidden_features=256,
                dropout_probability=0.1,
                num_blocks=3,
            )

    def forward(self, x, z):
        """Return ``{'z_lens': ...}`` ratio estimates for image ``x['img']`` and ``z['z_lens']``."""
        # Image -> standardised image -> summary features.
        standardized = self.online_z_score({'img': x['img']})['img']
        features = self.CNN(standardized.unsqueeze(1)).squeeze(1)

        if not WORKS:
            return {'z_lens': self.c(features, z['z_lens'])}

        # v1 path: transform parameters, align shapes, classify.
        params = self.ptrans(z['z_lens'])
        features, params = equalize_tensors(features, params)
        ratios = self.classifier(features, params)
        return {'z_lens': RatioSamples(params.squeeze(-1), ratios)}

class SourceNetwork(SwyftModule):
    """Ratio estimator for the source image ``z_src``.

    The observed image is deprojected to the source plane via the stored kNN
    indices and compared with ``z_src`` by a 1-dim Gaussian ratio estimator
    (the v1 or the current implementation, depending on ``WORKS``).
    """

    def __init__(self):
        super().__init__()
        self.l = nn.Linear(10, 10)

        estimator_cls = RatioEstimatorGaussian1d_v1 if WORKS else RatioEstimatorGaussian1d
        self.reg1d = estimator_cls(momentum=0.1)

        n_src_pixels = NPIX_SRC**2
        self.L = nn.Linear(n_src_pixels, n_src_pixels)

    def get_img_rec(self, x):
        """Deproject ``x['img']`` to the source plane and average over dim 1."""
        deprojected = lensing_model.deproject_idx(x['img'], x['kNN_idx'])
        src_rec = deprojected[:, :, :, :].mean(dim=1)

        # The linear layer's output is multiplied by zero, so for finite values
        # it does not change the result; it only keeps self.L in the graph.
        # NOTE(review): presumably so the module has trainable parameters - confirm.
        flat = src_rec.view(-1, NPIX_SRC*NPIX_SRC)
        zeroed = self.L(flat).view(-1, NPIX_SRC, NPIX_SRC)*0
        return zeroed + src_rec

    def forward(self, x, z):
        """Return ``{'z_src': ...}`` ratio estimates for the reconstructed source vs ``z['z_src']``."""
        src_rec, z_src = equalize_tensors(self.get_img_rec(x), z['z_src'])
        estimate = self.reg1d(src_rec, z_src)
        if WORKS:
            # The v1 estimator returns a tensor; wrap its last-axis entries
            # 1 and 0 as the two RatioSamples arguments, in that order.
            estimate = RatioSamples(estimate[..., 1], estimate[..., 0])
        return dict(z_src=estimate)

In [8]:
# Ntrain1: number of training samples requested from m.sample
# R1:      not used in the cells visible here
# ME:      max_epochs passed to SwyftTrainer
Ntrain1, R1, ME = 500, 1, 10
# Index into s_targets selecting the target sample.
TARGET = 3
# Run label; only referenced by the commented-out TensorBoard logger name.
tag = 'V03'
# NOTE(review): not read by any cell visible here - confirm where it is used.
INFER_SOURCE = True

In [9]:
# Train the lens-parameter ratio estimator on samples from the prior
# (bounds = None means no truncation).
bounds = None
results = []
# Selected target sample.
s0 = s_targets[TARGET]

#tbl = pl_loggers.TensorBoardLogger("lightning_logs", name = 'lensing_%s'%tag)#, default_hp_metric=True)
# s1: img, lens, src ~ p(img|lens, src)p(lens)p(src)
# NOTE(review): the cache path is fixed. Presumably file_cache reuses an existing
# file, in which case changing Ntrain1 or bounds would not regenerate the data;
# confirm, and delete './a_train_data.pt' when the settings change.
s1 = file_cache(lambda: m.sample(Ntrain1, bounds = bounds), './a_train_data.pt')

# r1: p(z_lens|img)/p(z_lens)
r1 = LensNetwork()

# d1: split img vs z_lens
# TODO: Specify x_keys = ['img'], z_keys=['z_lens']
d1 = SwyftDataModule(s1, model = m, batch_size = 128)

# Train r1 with d1 on a single GPU for ME epochs, then evaluate it with the
# trainer's test loop.
t1 = SwyftTrainer(accelerator = 'gpu', gpus=1, max_epochs = ME)
t1.fit(r1, d1)
t1.test(r1, d1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                         | Params
----------------------------------------------------------------
0 | online_z_score | OnlineDictStandardizingLayer | 0     
1 | CNN            | Sequential                   | 13.3 K
2 | c              | RatioEstimatorMLP1d          | 110 K 
----------------------------------------------------------------
123 K     Trainable params
0         Non-trainable params
123 K     Total params
0.494     Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': 7.863509654998779, 'hp/KL-div': -1.4944143295288086}
--------------------------------------------------------------------------------


[{'hp/JS-div': 7.863509654998779, 'hp/KL-div': -1.4944143295288086}]