In [1]:
%load_ext autoreload
%autoreload 2

In [10]:
import os
from itertools import product

import xarray as xr
import pandas as pd
import numpy as np
import larch.numba as lx
from larch import P, X

from numpy.random import default_rng

In [6]:
def one_based(n):
    """Return the 1-based labels ``1, 2, ..., n`` as a ``pd.RangeIndex`` (empty for n == 0)."""
    stop = n + 1  # RangeIndex excludes its stop value
    return pd.RangeIndex(start=1, stop=stop)

def from_numpy(
    numpy_data,
    name,
    index_names=("otaz", "dtaz"),
    indexes="one-based",
    renames=None,
):
    """Wrap an array as a single-variable ``xarray.Dataset``.

    Parameters
    ----------
    numpy_data : array-like
        Values of the variable; converted with ``np.asarray``. Its number of
        dimensions must equal ``len(index_names)``.
    name : str
        Name of the data variable in the returned Dataset.
    index_names : sequence of str, default ("otaz", "dtaz")
        Dimension names, one per axis of ``numpy_data``.
    indexes : "one-based", mapping, or None, default "one-based"
        Coordinate labels. ``"one-based"`` labels every axis ``1..n`` via
        ``one_based``; a mapping of dimension name -> index values is used
        as given; ``None`` attaches no coordinates.
    renames : optional
        Currently unused; accepted only for interface compatibility.

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    ValueError
        If the array's number of dimensions differs from ``len(index_names)``.
    """
    numpy_data = np.asarray(numpy_data)
    index_names = tuple(index_names)
    if numpy_data.ndim != len(index_names):
        raise ValueError(
            f"numpy_data has {numpy_data.ndim} dimension(s) but "
            f"{len(index_names)} index name(s) were given: {index_names}"
        )

    # isinstance guard: a mapping (or pandas object) must not be compared
    # to the sentinel string.
    if isinstance(indexes, str) and indexes == "one-based":
        indexes = {
            index_name: one_based(size)
            for index_name, size in zip(index_names, numpy_data.shape)
        }

    coords = {}
    if indexes is not None:
        coords = {
            index_name: {"dims": index_name, "data": index}
            for index_name, index in indexes.items()
        }

    d = {
        "dims": index_names,
        "data_vars": {name: {"dims": index_names, "data": numpy_data}},
        # "coords" must always be present, even when empty. Dataset.from_dict
        # only reads the "coords"/"data_vars" sections when BOTH keys exist.
        # Otherwise it treats every top-level key (including "dims") as a
        # variable and fails.
        "coords": coords,
    }
    return xr.Dataset.from_dict(d)

In [3]:
# Problem size: number of destination zones and of trip purposes.
num_destinations, num_purposes = 2, 3

# Create data

In [14]:
num_alternatives = num_purposes * num_destinations  # one alternative per (destination, purpose) pair

In [47]:
# Map each (destination, purpose) pair to an alternative code 1..N.
# Numbering is destination-major: all purposes of destination 1 come first,
# following itertools.product ordering.
choice_mapping = {
    (dest, purpose): altid
    for altid, (dest, purpose) in enumerate(
        product(range(1, num_destinations + 1), range(1, num_purposes + 1)),
        start=1,
    )
}
display(choice_mapping)

{(1, 1): 1, (1, 2): 2, (1, 3): 3, (2, 1): 4, (2, 2): 5, (2, 3): 6}

In [48]:
alternatives = [*choice_mapping.values()]  # alternative codes, in mapping order

## attractions

In [49]:
# Per-alternative attributes, indexed by altid. Attractions are drawn from a
# seeded Generator so that Restart & Run All reproduces the same table.
ATTRACTIONS_SEED = 12345
attractions_rng = default_rng(ATTRACTIONS_SEED)

attrs = (
    pd.DataFrame(
        [(dest, purpose, altid) for (dest, purpose), altid in choice_mapping.items()],
        columns=["destination", "purpose", "altid"],
    )
    # Integers in [1, 99] (high bound exclusive), the same range as the
    # former np.random.randint(1, 100), so every attraction is positive.
    .assign(attractions=lambda frame: attractions_rng.integers(1, 100, size=len(frame)))
    .set_index("altid")
)

In [50]:
# Show the per-alternative table: destination, purpose and attractions, indexed by altid.
display(attrs)

Unnamed: 0_level_0,destination,purpose,attractions
altid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1,1,77
2,1,2,29
3,1,3,57
4,2,1,82
5,2,2,38
6,2,3,44


## skims

In [51]:
# Square origin x destination distance matrix, all zeros.
# NOTE(review): with zero distances X.distance has no variation, so the
# distance coefficient cannot be identified. Presumably placeholder data.
skims = np.zeros((num_destinations,) * 2)

In [58]:
# Wrap the distance matrix as a labelled Dataset: dims otaz/dtaz with 1-based
# coordinates, variable "distance". This cell rebinds `skims` from ndarray to
# Dataset, so the isinstance guard keeps a re-run from passing a Dataset back
# into from_numpy (which needs an array).
if not isinstance(skims, xr.Dataset):
    skims = from_numpy(skims, "distance")

In [59]:
# Inspect the skim Dataset built from the distance matrix.
display(skims)

## observations

In [70]:
# Observed cases: origin zone and chosen alternative code (values of
# choice_mapping), one row per caseid.
case_ids = pd.Index([1, 2, 3], name="caseid")
obs = pd.DataFrame(
    {
        "origin": [1, 2, 2],
        "chosen": [4, 2, 3],
    },
    index=case_ids,
)

# Build and evaluate the Larch model

In [71]:
# Assemble the Larch data tree from the case table, the per-alternative
# attributes and the skims, joined through the relationship strings below.
tree = lx.DataTree(
    # Case data, expanded over the alternative codes built above.
    obs=lx.Dataset.construct(obs, caseid="caseid", alts=alternatives),
    attr=attrs,
    skims=skims,
    relationships=(
        # presumably: the alternative-id dimension created by Dataset.construct
        # -> attrs rows (indexed by altid) — TODO confirm "_altid_" naming
        "obs._altid_ @ attr.altid",
        # case origin zone -> skim origin coordinate
        "obs.origin @ skims.otaz",
        # alternative's destination zone -> skim destination coordinate
        "attr.destination @ skims.dtaz",
    ),
)

In [72]:
m = lx.Model(datatree=tree)
m.title = "blah"

# Size (quantity) term: attractions with coefficient `zero`, scaled by `Theta`.
m.quantity_ca = P.zero * X.attractions
m.quantity_scale = P.Theta

# Utility: linear in skim distance.
m.utility_ca = P.distance * X.distance

# Column of the case data holding the chosen alternative code.
m.choice_co_code = "chosen"

# TODO: availability is currently unrestricted; candidate rule:
# m.availability_var = "attr.attractions > 0"

In [73]:
# One nest per destination grouping that destination's purpose alternatives,
# all sharing the MuDest nest parameter.
# NOTE(review): node names are zero-based ("dest_0" holds destination 1) while
# destination ids elsewhere are one-based. Kept as-is in case other cells use them.
for dest_id in range(1, num_destinations + 1):
    purpose_alts = [
        choice_mapping[(dest_id, purpose_id)]
        for purpose_id in range(1, num_purposes + 1)
    ]
    m.graph.new_node(
        parameter="MuDest",
        children=purpose_alts,
        name=f"dest_{dest_id - 1}",
    )

In [74]:
# Display the nesting structure: destination nests over purpose alternatives.
m.graph

In [75]:
# Hold these parameters fixed so they are not estimated:
#   MuDest = 1   nest parameter shared by every destination node
#   zero   = 0   coefficient on attractions in the quantity term
#   Theta  = 1   quantity scale
# The distance coefficient is left free.
m.lock_values(MuDest=1, zero=0, Theta=1.0)

In [76]:
# Evaluate the model's log-likelihood for the observed choices.
# NOTE(review): the "hello from ed" / array lines in the saved output are not
# printed by this notebook. They are presumably leftover debug prints in the
# installed larch code; clear them before sharing.
m.loglike()

hello from ed
[6 6 6 7 7 7 8 8]
[0 1 2 3 4 5 6 7]
hello from ed
[6 6 6 7 7 7 8 8]
[0 1 2 3 4 5 6 7]
hello from ed
[6 6 6 7 7 7 8 8]
[0 1 2 3 4 5 6 7]


-5.552814167606483