In [2]:
# %load init.py
#@title Import & Utilities
from __future__ import annotations

class AutoImportError(ImportError):
    """Raised when a module still cannot be provided after a pip install attempt."""

class auto(object):
    """Namespace object that imports modules lazily on attribute access.

    ``auto.json`` imports the ``json`` module, caches it as an instance
    attribute and returns it. When the import fails and the name has an entry
    in ``registry``, the registered pip requirements are installed with the
    running interpreter and the import is retried once.

    NOTE(review): the attribute name is passed verbatim to
    ``importlib.import_module``, so a registration only helps when the
    attribute name is itself an importable module name.
    """

    # Maps an import name to the pip requirement names to install for it.
    registry: ClassVar[Dict[str, Tuple[str, ...]]] = {}

    @classmethod
    def register(
        cls,
        import_name: str,
        package_name: Optional[str]=None,
        *extra_package_names: str,
    ):
        """Record which pip packages to install when ``import_name`` is missing.

        Args:
            import_name: Attribute / module name looked up on the instance.
            package_name: Primary pip requirement; defaults to ``import_name``.
            *extra_package_names: Additional requirements installed alongside.
        """
        if package_name is None:
            package_name = import_name

        cls.registry[import_name] = (
            package_name,
            *extra_package_names,
        )

    def __getattr__(self, import_name: str):
        """Import ``import_name``, installing its registered packages if needed.

        Only invoked after normal attribute lookup has already failed.

        Raises:
            AttributeError: the module is not importable and has no registry
                entry (keeps ``hasattr`` and attribute probing well-behaved).
            AutoImportError: pip failed, or the import still fails afterwards.
        """
        import subprocess, importlib, sys

        try:
            module = importlib.import_module(import_name)
        except ImportError as e:
            try:
                package_names = self.registry[import_name]
            except KeyError:
                raise AttributeError(
                    f"{import_name!r} is not importable and has no registered pip package"
                ) from e

            process = subprocess.run([
                sys.executable,
                '-m', 'pip',
                'install',
                *package_names,
            ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

            if process.returncode != 0:
                # With text=True, stdout is already a str (not a file object).
                raise AutoImportError(f"Failed to pip install {package_names!r}\n\n{process.stdout}") from e

            try:
                module = importlib.import_module(import_name)
            except ImportError as e:
                raise AutoImportError('Import failed a second time, even after a pip install') from e

        # Cache on the instance so later lookups bypass __getattr__ entirely.
        setattr(self, import_name, module)
        return module

# (import name, pip package name); None means "same as the import name".
for _import_name, _package_name in (
    ('np', 'numpy'),
    ('tqdm', None),
    ('more_itertools', 'more-itertools'),
    ('torch', None),
    ('peft', None),
    ('guidance', None),
    ('langchain', None),
    ('diffusers', None),
):
    auto.register(_import_name, _package_name)

# Each of these import names installs the same set of packages together.
_HF_PACKAGES = ('transformers', 'accelerate', 'datasets', 'tokenizers', 'evaluate', 'huggingface_hub', 'torch')
for _import_name in ('transformers', 'accelerate', 'datasets', 'tokenizers', 'evaluate', 'huggingface_hub'):
    auto.register(_import_name, None, *_HF_PACKAGES)

# Do not leave the loop helpers behind in the notebook namespace.
del _import_name, _package_name, _HF_PACKAGES

auto = auto()


def doctest(func=None, /, verbose=False, sterile=False):
    """Decorator that runs a function's docstring examples when it is defined.

    Usable bare (``@doctest``) or with options (``@doctest(verbose=True)``).

    Args:
        func: The function to check; omitted when called with options.
        verbose: Passed through to the doctest finder and runner.
        sterile: If true, the examples only see the function itself; otherwise
            they see a shallow copy of this module's globals plus the function.

    Returns:
        The undecorated function (or a decorator when ``func`` is omitted).

    Raises:
        AssertionError: if any docstring example fails.
    """
    def wrapper(func):
        # Thanks https://stackoverflow.com/a/49659927
        import copy
        import doctest as doctest_module

        # I need this to error out on failure; the default one doesn't.
        def run_docstring_examples(f, globs, verbose=False, name="NoName", compileflags=None, optionflags=0):
            finder = doctest_module.DocTestFinder(verbose=verbose, recurse=False)
            runner = doctest_module.DocTestRunner(verbose=verbose, optionflags=optionflags)
            # Use the argument, not the enclosing closure's function.
            for test in finder.find(f, name, globs=globs):
                runner.run(test, compileflags=compileflags)
            # Raise explicitly: a bare `assert` disappears under `python -O`.
            if runner.failures != 0:
                raise AssertionError(f'{runner.failures} doctest example(s) failed in {name}')

        name = func.__name__

        if sterile:
            globs = {}
        else:
            globs = copy.copy(globals())
        # Let the examples call the function by name before it is bound.
        globs[name] = func
        run_docstring_examples(func, globs, verbose=verbose, name=name)
        return func

    if func is not None:
        return wrapper(func)
    else:
        return wrapper

# Notebook-wide registries, created once and kept when this cell is re-run:
#   g: name -> value returned by a @run-decorated function
#   f: name -> the @run-decorated function itself
if 'g' not in globals():
    g = {}

if 'f' not in globals():
    f = {}

def run(func=None, /, name=None, cond=True, splat=False, after=None, scope=None):
    """Decorator that calls a function immediately and stores its result in ``g``.

    Each positional-only parameter of the function is filled from ``g`` by
    parameter name (trying ``f'{scope}__{param}'`` first when ``scope`` is
    given). The function is recorded in ``f`` and its result in ``g``. The
    decorated name itself is always bound to ``None``.

    Args:
        func: The function to run; omitted when called with options.
        name: Key to store under; defaults to ``func.__name__``.
        cond: Boolean, or a callable returning one; when falsy nothing runs.
        splat: If true, the result's ``.items()`` are stored individually
            instead of the result as a whole.
        after: Optional callable invoked with the result before it is stored.
        scope: Optional prefix, joined with ``__``, applied to stored keys.

    Raises:
        KeyError: a positional-only parameter has no matching entry in ``g``.
    """
    def wrapper(func, /, *, name=name, cond=cond):
        import inspect

        enabled = cond() if callable(cond) else cond
        if not enabled:
            return None

        label = func.__name__ if name is None else name
        f[label] = func

        positional_only = [
            param_name
            for param_name, param in inspect.signature(func).parameters.items()
            if param.kind == inspect.Parameter.POSITIONAL_ONLY
        ]

        args = []
        for param_name in positional_only:
            # The scoped key, when there is one, takes precedence.
            keys = [param_name] if scope is None else [f'{scope}__{param_name}', param_name]
            for candidate in keys:
                if candidate in g:
                    args.append(g[candidate])
                    break
            else:
                raise KeyError(f'None of {keys=!r} found in g')

        result = func(*args)

        if callable(after):
            after(result)

        entries = result.items() if splat else [(label, result)]
        for key, value in entries:
            g[key if scope is None else f'{scope}__{key}'] = value

        return None

    return wrapper if func is None else wrapper(func)

@auto.IPython.core.magic.register_line_magic
@auto.IPython.core.magic.register_cell_magic
def source(magic_line, magic_cell=None):
    """IPython line/cell magic that turns a shell script's exports into Python.

    With no cell body: runs ``source <magic_line>; export`` in bash, collects
    the exported variables whose values differ from the current
    ``os.environ``, generates a ``%%source`` cell containing an
    ``os.environ |= {...}`` statement for them, replaces the current input
    with that cell, and runs it.

    With a cell body: just runs the body.

    Args:
        magic_line: Text after the magic name; interpolated into the bash
            command unquoted.
        magic_cell: Cell body, or ``None``/``''`` when there is none.
    """
    import os, subprocess, shlex

    if magic_cell is None or magic_cell == '':
        # Snapshot so only variables the script changed are emitted.
        before = os.environ.copy()

        process = subprocess.run([
            'bash', '-c', f'source {magic_line}; export',
        ], capture_output=True, text=True)

        after = {}
        for line in process.stdout.split('\n'):
            if line == '': continue
            # Each line is expected to look like: declare -x NAME="value"
            parts = shlex.split(line)
            assert parts[0] == 'declare', f'{line=!r}'
            assert parts[1] == '-x', f'{line=!r}'
            # Exported names without a value carry nothing to copy.
            if '=' not in parts[2]: continue
            name, value = parts[2].split('=', 1)

            if before.get(name, None) == value: continue
            after[name] = value

        # Build the replacement cell source text.
        magic_cell = f'%%source {magic_line}\n'
        magic_cell += f'os.environ |= {{\n'
        for name, value in after.items():
            magic_cell += f'  {name!r}: '
            if ':' in value:
                # Colon-separated values are written one component per line.
                magic_cell += f'":".join([\n'
                for value in value.split(':'):
                    magic_cell += f'    {value!r},\n'
                magic_cell += f'  ]),\n'
            else:
                magic_cell += f' {value!r},\n'
        magic_cell += f'}}\n'

        get_ipython().set_next_input(magic_cell, replace=True)

    # NOTE(review): a generated cell starts with `%%source`, so this
    # presumably re-enters this magic with the body as `magic_cell` — confirm.
    get_ipython().run_cell(magic_cell)

In [3]:
@run(cond='cache' not in g)
def cache():
    """Seed ``g['cache']`` with an empty dict unless one already exists."""
    return dict()

@run
def cache(cache, /):
    """Merge the on-disk pickle cache into ``cache`` and write the union back.

    Entries already in memory win over entries on disk. Load, merge and dump
    failures are printed with a traceback rather than raised.

    Args:
        cache: The in-memory cache dict (mutated in place).

    Returns:
        The same ``cache`` dict.
    """
    def load() -> Dict:
        # NOTE: unpickling can execute arbitrary code; only point this at
        # cache files this notebook wrote itself.
        ret = {}
        try:
            with open(path, 'rb') as f:
                ret = auto.pickle.load(f)
                print(f'Read {f.tell():,d} bytes from {path}')
        except Exception as e:
            print('load failed:')
            auto.traceback.print_exception(e)

        return ret

    def merge(mutable: Dict, constant: Dict):
        # Copy over only the keys that `mutable` does not have yet.
        try:
            for k, v in constant.items():
                if k in mutable:
                    continue

                mutable[k] = v
        except Exception as e:
            print('merge failed:')
            auto.traceback.print_exception(e)

    def dump(d: Dict):
        # Write to a sibling temp file and rename it over the real cache, so
        # a failed or interrupted dump cannot truncate the existing file.
        tmp_path = path.with_name(path.name + '.tmp')
        try:
            with open(tmp_path, 'wb') as f:
                auto.pickle.dump(d, f)
                size = f.tell()

            tmp_path.replace(path)
            print(f'Wrote {size:,d} bytes to {path}')
        except Exception as e:
            print('dump failed:')
            auto.traceback.print_exception(e)

    path = auto.pathlib.Path.cwd() / 'tmp' / 'Sunrise Demo Cache.pickle'
    path.parent.mkdir(parents=False, exist_ok=True)

    if path.exists():
        merge(cache, load())

    dump(cache)

    return cache

Read 14,473,750 bytes from /home/thobson2/src/Sunrise-Demo/tmp/Sunrise Demo Cache.pickle
Wrote 14,473,750 bytes to /home/thobson2/src/Sunrise-Demo/tmp/Sunrise Demo Cache.pickle


In [4]:
@run
def fetch(cache, /):
    """Build a URL fetcher that memoizes response bodies in ``cache``.

    Args:
        cache: Dict mapping URL to response bytes (mutated on cache misses).

    Returns:
        ``fetch(url, *, tqdm=None) -> bytes``.
    """
    def fetch(url: str, *, tqdm=None) -> bytes:
        """Return the body for ``url``, downloading it only on a cache miss.

        Args:
            url: Address to fetch; also used as the cache key.
            tqdm: Optional progress bar whose description is updated with the
                hit/miss status.
        """
        key = url
        if key in cache:
            if tqdm is not None:
                tqdm.set_description(f'Cache Hit: {url}')
            return cache[key]

        if tqdm is not None:
            tqdm.set_description(f'Cache Miss: {url}')

        # NOTE(review): the body is cached whatever the HTTP status is.
        with auto.requests.get(url) as r:
            cache[key] = r.content

        # Pause after every real request (presumably to be polite to the
        # server — confirm).
        auto.time.sleep(1)

        return cache[key]

    return fetch

In [5]:
# Distinct static types over float, each named for the unit of the quantity.
# At runtime a NewType call returns its argument unchanged.
Degree = auto.typing.NewType('Degree', float)
Radian = auto.typing.NewType('Radian', float)
Meter = auto.typing.NewType('Meter', float)
Kilometer = auto.typing.NewType('Kilometer', float)

In [32]:
@auto.dataclasses.dataclass(eq=True, order=True, frozen=True)
class Tile:
    """Immutable, hashable map tile address: a zoom level plus x/y indices."""

    zoom: int
    x: int
    y: int

    @auto.functools.singledispatch
    def from_(*args, **kwargs) -> Tile:
        """Single-dispatch conversion to a Tile, keyed on the first argument's type.

        Converters are attached with ``Tile.from_.register(SourceType)``; this
        fallback runs for types with no registered converter.
        """
        raise NotImplementedError

@auto.dataclasses.dataclass(eq=True, order=True, frozen=True)
class Location:
    """Immutable geographic coordinate: latitude, longitude, optional altitude."""

    lat: Degree
    lng: Degree
    # Left as None when no altitude is given.
    alt: Optional[Kilometer] = auto.dataclasses.field(default=None)

    @auto.functools.singledispatch
    def from_(*args, **kwargs) -> Location:
        """Single-dispatch conversion to a Location, keyed on the first argument's type.

        Converters are attached with ``Location.from_.register(SourceType)``;
        this fallback runs for types with no registered converter.
        """
        raise NotImplementedError

@auto.dataclasses.dataclass(eq=True, order=True, frozen=True)
class Position:
    """Immutable 3-component position; each component is annotated Kilometer."""

    x: Kilometer
    y: Kilometer
    z: Kilometer

    @auto.functools.singledispatch
    def from_(*args, **kwargs) -> Position:
        """Single-dispatch conversion to a Position, keyed on the first argument's type.

        This fallback runs for types with no registered converter.
        """
        raise NotImplementedError

In [34]:
@Tile.from_.register(Location)
@doctest
def convert_location_to_tile(location: Location, /, zoom: int, *, math=auto.math) -> Tile:
    r"""Return the tile at ``zoom`` whose area contains ``location``.

    >>> convert_location_to_tile(Location(lat=35.6, lng=-83.52), zoom=11)
    Tile(zoom=11, x=548, y=807)
    >>> convert_location_to_tile(Location(lat=35.0, lng=-84.5), zoom=15)
    Tile(zoom=15, x=8692, y=12979)
    >>> convert_location_to_tile(Location(lat=36.25, lng=-82.5), zoom=15)
    Tile(zoom=15, x=8874, y=12839)

    """
    # Formulas from https://stackoverflow.com/a/72476578

    # Number of tiles along each axis at this zoom level.
    tiles_per_axis = 2.0 ** zoom
    latitude = math.radians(location.lat)

    # Longitude maps linearly onto the column index.
    column = int((location.lng + 180.0) / 360.0 * tiles_per_axis)

    # Latitude goes through the Mercator projection to give the row index.
    row = int((1.0 - math.asinh(math.tan(latitude)) / math.pi) / 2.0 * tiles_per_axis)

    return Tile(zoom=zoom, x=column, y=row)

In [31]:
Tile.from_(Location(lat=35.6, lng=-83.52), zoom=11)

Tile(zoom=11, x=548, y=807)

In [38]:
@Location.from_.register(Tile)
@doctest
def convert_tile_to_location(tile: Tile, /, where: str, *, math=auto.math) -> Location:
    r"""Return the location of a corner or the center of ``tile``.

    Args:
        tile: The tile to locate.
        where: ``'nw'`` (north-west corner), ``'se'`` (south-east corner) or
            ``'center'``.

    Raises:
        NotImplementedError: for any other ``where``.

    >>> convert_tile_to_location(Tile(zoom=15, x=8739, y=12925), where='nw')
    Location(lat=35.48751102385376, lng=-83.990478515625, alt=None)
    >>> convert_tile_to_location(Tile(zoom=1, x=0, y=0), where='se')
    Location(lat=0.0, lng=0.0, alt=None)
    >>> convert_tile_to_location(Tile(zoom=0, x=0, y=0), where='center')
    Location(lat=0.0, lng=0.0, alt=None)

    """

    # Re-express the requested point as the NW corner of some tile.
    if where == 'nw':
        zoom = tile.zoom
        x = tile.x
        y = tile.y

    elif where == 'se':
        # The SE corner is the NW corner of the diagonal neighbour.
        zoom = tile.zoom
        x = tile.x + 1
        y = tile.y + 1

    elif where == 'center':
        # The center is the NW corner of the tile's SE child, one zoom deeper.
        zoom = tile.zoom + 1
        x = 2 * tile.x + 1
        y = 2 * tile.y + 1

    else:
        raise NotImplementedError(f'Unsupported {where=!r}')

    # Thanks https://gis.stackexchange.com/a/133535
    # The adjusted zoom/x/y must be used below, not the tile's own fields.

    #> n = 2 ^ zoom
    n = 2 ** zoom

    #> lon_deg = xtile / n * 360.0 - 180.0
    lon_deg = x / n * 360.0 - 180.0

    #> lat_rad = arctan(sinh(π * (1 - 2 * ytile / n)))
    lat_rad = math.atan(math.sinh(math.pi * (1.0 - 2.0 * y / n)))

    #> lat_deg = lat_rad * 180.0 / π
    lat_deg = lat_rad * 180.0 / math.pi

    return Location(lat=lat_deg, lng=lon_deg)

In [51]:
@run
def __mirror_tile_images(fetch, /):
    """Download every tile in the study area into ``data/<layer>/``.

    For each of the three layers (background, terrain, observation) the tiles
    between the two corner tiles are fetched and written as
    ``z{z}x{x}y{y}.png``. Files that already exist are skipped.

    Args:
        fetch: The fetcher taken from ``g['fetch']``; called as
            ``fetch(url, tqdm=...)``.

    NOTE(review): the URL templates below embed access tokens in source.
    """
    sw = Tile.from_(Location(lat=35.0, lng=-84.5), zoom=11)
    ne = Tile.from_(Location(lat=36.25, lng=-82.5), zoom=11)
    assert sw.x < ne.x
    assert sw.y > ne.y

    def mirror(*, path: Path, template: str):
        """Mirror one layer: fill ``template`` per tile and save under ``path``."""
        def mirror_one(*, path: Path, url: str, tqdm):
            if path.exists():
                return

            path.parent.mkdir(parents=True, exist_ok=True)
            # Forward the progress bar so the hit/miss description is shown.
            path.write_bytes(fetch(url, tqdm=tqdm))

        template = auto.string.Template(template)

        y0 = min(sw.y, ne.y)
        y1 = max(sw.y, ne.y)

        x0 = min(sw.x, ne.x)
        x1 = max(sw.x, ne.x)

        z0 = min(sw.zoom, ne.zoom)
        z1 = max(sw.zoom, ne.zoom)

        for z, y, x in (tqdm := auto.tqdm.tqdm(auto.itertools.product(
            range(z0, z1 + 1),
            range(y0, y1 + 1),
            range(x0, x1 + 1),
        ), total=(z1-z0+1)*(x1-x0+1)*(y1-y0+1))):
            mirror_one(
                path=path / f'z{z}x{x}y{y}.png',
                url=template.substitute(
                    z=str(z),
                    x=str(x),
                    y=str(y),
                ),
                tqdm=tqdm,
            )

    mirror(
        path=auto.pathlib.Path.cwd() / 'data' / 'background',
        template=(
            r"""https://atlas-stg.geoplatform.gov/styles/v1/atlas-user/ck58pyquo009v01p99xebegr9/tiles/256/${z}/${x}/${y}?access_token=pk.eyJ1IjoiYXRsYXMtdXNlciIsImEiOiJjazFmdGx2bjQwMDAwMG5wZmYwbmJwbmE2In0.lWXK2UexpXuyVitesLdwUg"""
        ),
    )

    mirror(
        path=auto.pathlib.Path.cwd() / 'data' / 'terrain',
        template=(
            r"""https://api.mapbox.com/v4/mapbox.terrain-rgb/${z}/${x}/${y}.pngraw?access_token=pk.eyJ1IjoidGhvYnNvbjIiLCJhIjoiY2oxZmdqbnQzMDBpbjJxczR6dWoyemUxNiJ9.SEBHSdHLP_lZGD43r-_IDQ"""
        ),
    )

    mirror(
        path=auto.pathlib.Path.cwd() / 'data' / 'observation',
        template=(
            r"""https://atlas-stg.geoplatform.gov:443/v4/atlas-user.0013547_pink/${z}/${x}/${y}.png?access_token=pk.eyJ1IjoiYXRsYXMtdXNlciIsImEiOiJjazFmdGx2bjQwMDAwMG5wZmYwbmJwbmE2In0.lWXK2UexpXuyVitesLdwUg"""
        ),
    )


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [02:30<00:00,  1.25s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [02:45<00:00,  1.38s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:55<00:00,  1.03it/s]


In [56]:
!ls data/{background,terrain,observation} | wc -l

433


## Modeling

In [None]:
@auto.dataclasses.dataclass(eq=True, frozen=True)
class BackgroundTile:
    """Placeholder for one background tile; it has no fields yet."""
    pass

@auto.dataclasses.dataclass(eq=True, frozen=True)
class Background:
    """Container for the background layer's tiles."""

    # The dataclass is frozen, but the list itself can still be mutated.
    tiles: list[BackgroundTile]

@run
def background() -> Background:
    """Create the (currently empty) background model stored as ``g['background']``."""
    background = Background(
        tiles=[],
    )

    # Without this return, @run would store None under 'background'.
    return background
    
    

In [71]:
@auto.dataclasses.dataclass(eq=True, frozen=True)
class Terrain:
    """Elevation rasters for the terrain layer, keyed by the tile they cover."""

    @auto.dataclasses.dataclass(eq=True, frozen=True)
    class Altitude:
        """Elevation raster for a single tile."""

        tile: Tile
        # Annotated as a float32 array of shape (rows, cols, 1).
        data: auto.np.ndarray[tuple[int, int, Literal[1]], auto.np.float32]

    # The stored values are Terrain.Altitude instances.
    altitudes: dict[Tile, Altitude]

@run
def terrain() -> Terrain:
    """Load the mirrored terrain tiles and decode them into elevation rasters.

    Reads ``data/terrain/z{z}x{x}y{y}.png`` for every tile between the two
    corner tiles and decodes the RGB channels with the Mapbox Terrain-RGB
    formula.

    Returns:
        A ``Terrain`` whose ``altitudes`` maps each tile to its
        ``Terrain.Altitude`` (float32 data of shape ``(256, 256, 1)``).
    """
    # Touch auto.PIL first so Pillow is imported (or installed) lazily.
    auto.PIL; import PIL.Image

    def tiles(*, sw: Tile, ne: Tile) -> Iterator[Tile]:
        """Yield every tile in the rectangle spanned by the two corner tiles."""
        assert sw.x < ne.x
        assert sw.y > ne.y
        assert sw.zoom == ne.zoom

        z0 = min(sw.zoom, ne.zoom)
        z1 = max(sw.zoom, ne.zoom)

        x0 = min(sw.x, ne.x)
        x1 = max(sw.x, ne.x)

        y0 = min(sw.y, ne.y)
        y1 = max(sw.y, ne.y)

        for z, x, y in auto.itertools.product(
            range(z0, z1 + 1),
            range(x0, x1 + 1),
            range(y0, y1 + 1),
        ):
            yield Tile(zoom=z, x=x, y=y)

    def decode(rgb):
        """Decode an (H, W, 3) Terrain-RGB array into (H, W, 1) float32."""
        # Thanks https://docs.mapbox.com/data/tilesets/guides/access-elevation-data/#decode-data
        #> elevation = -10000 + (({R} * 256 * 256 + {G} * 256 + {B}) * 0.1)

        # Widen before multiplying: presumably the channels arrive as uint8
        # (TODO confirm), where `* 65536` relies on NumPy 1.x value-based
        # promotion and raises OverflowError under NumPy 2.
        channels = rgb.astype(auto.np.int64)
        raw = (
            channels[:, :, 0] * (256 * 256)
            + channels[:, :, 1] * 256
            + channels[:, :, 2]
        )

        # raw is below 2**24, so the float32 conversion is exact.
        altitude = raw.astype(auto.np.float32)[:, :, auto.np.newaxis]
        altitude *= 0.1
        altitude -= 10000
        return altitude

    terrain = Terrain(
        altitudes={},
    )

    for tile in tiles(
        sw=Tile.from_(Location(lat=35.0, lng=-84.5), zoom=11),
        ne=Tile.from_(Location(lat=36.25, lng=-82.5), zoom=11),
    ):
        path = auto.pathlib.Path.cwd() / 'data' / 'terrain' / f'z{tile.zoom}x{tile.x}y{tile.y}.png'
        with open(path, 'rb') as f:
            image = auto.PIL.Image.open(f)
            # Force the pixel data to be read while the file is still open.
            image.load()

        image = image.convert('RGB')
        rgb = auto.np.array(image)
        assert rgb.shape == (256, 256, 3), \
            f"Wrong shape: {rgb.shape=!r}"

        altitude = decode(rgb)
        assert altitude.shape == (256, 256, 1), \
            f"Wrong shape: {altitude.shape=!r}"

        terrain.altitudes[tile] = Terrain.Altitude(
            tile=tile,
            data=altitude,
        )

    return terrain