In [3]:
import torchlensmaker as tlm
import torch

import json

from torchlensmaker.new_light_sources.light_sources_elements import GenericLightSource, Object2D, ObjectAtInfinity2D, PointSource2D, RaySource2D
from torchlensmaker.new_light_sources.source_geometry_elements import ObjectGeometry2D

from torchlensmaker.new_sampling.sampling_elements import LinspaceSampler1D
from torchlensmaker.new_material.material_elements import NonDispersiveMaterial


def describe(name, t):
    print(name)
    print(t.shape)
    print(t.min(), t.max())
    print()

# Generic light source assembled from explicit samplers, a material and an
# object geometry. It is constructed here but is not part of `optics` below.
lgeneric = GenericLightSource(
    sampler_pupil=LinspaceSampler1D(6),
    sampler_field=LinspaceSampler1D(6),
    sampler_wavelength=LinspaceSampler1D(6),
    material=NonDispersiveMaterial(1.5),
    geometry=ObjectGeometry2D(
        beam_angular_size=20,
        object_diameter=5,
        wavelength=(600, 800),
    ),
)

# Sequential system under test: rotation, ray source, gap, refractive
# spherical surface, gap, aperture.
# Disabled alternatives for the source element (swap in for RaySource2D()):
#   Object2D(beam_angular_size=5, object_diameter=5)
#   ObjectAtInfinity2D(beam_diameter=5, angular_size=10)
#   PointSource2D(beam_angular_size=30)
optics = tlm.Sequential(
    tlm.Rotate2D(15),
    RaySource2D(),
    tlm.Gap(10),
    tlm.RefractiveSurface(tlm.Sphere(15, 50), material="BK7"),
    tlm.Gap(50),
    tlm.Aperture(50),
)

# Show the module tree of the assembled system.
print(optics)

# Evaluate the system once on the library's default 2D input.
data = tlm.default_input(
    dim=2,
    dtype=torch.float64,
    sampling={"base": 5, "object": 3, "wavelength": 3},
)
outputs = optics(data)

# Render in 2D and keep the returned scene object.
scene = tlm.show2d(optics, return_scene=True)
# Disabled: dump the scene to disk.
#   with open("testnb.json", "w") as f:
#       json.dump(scene, f)


Sequential(
  (0): Rotate2D()
  (1): RaySource2D(
    (pupil_sampler): ZeroSampler1D()
    (field_sampler): ZeroSampler1D()
    (sampler_wavelength): LinspaceSampler1D()
    (material): NonDispersiveMaterial()
    (geometry): ObjectAtInfinityGeometry2D()
  )
  (2): Gap(
    (mixed_dim): MixedDim(
      (module_2d): Translate2D()
      (module_3d): Translate3D()
    )
  )
  (3): RefractiveSurface(
    (collision_surface): CollisionSurface()
  )
  (4): Gap(
    (mixed_dim): MixedDim(
      (module_2d): Translate2D()
      (module_3d): Translate3D()
    )
  )
  (5): Aperture(
    (collision_surface): CollisionSurface()
  )
)
