## This notebook starts the same way as the `fabrication` notebook

In [1]:
import csv
import numpy as np
from spectra import Spectra, Pigment
import matplotlib.pyplot as plt
import numpy.typing as npt
import torch as th

In [2]:
from metamers import Cone, Observer
from models import gaussian, Neaugebauer

In [3]:
primary_fns = [
    "000",
    "001",
    "010",
    "100",
    "011",
    "110",
    "101",
    "111",
]

In [4]:
# Load the measured reflectance of each Neugebauer primary from
# PrintColors/<key>.csv (';'-separated). Only row SPECTRUM_ROW is used, and its
# values from column SPECTRUM_FIRST_COL onward are the reflectance samples on
# the 400-700 nm, 10 nm `wavelengths` grid.
SPECTRUM_ROW = 4
SPECTRUM_FIRST_COL = 33

spectra_primaries_dict = {}
wavelengths = np.arange(400, 701, 10)

for fn in primary_fns:
    with open(f'PrintColors/{fn}.csv', newline='') as csvf:
        for i, row in enumerate(csv.reader(csvf, delimiter=';')):
            if i == SPECTRUM_ROW:
                color_data = np.array(row[SPECTRUM_FIRST_COL:], dtype=float)
                spectra_primaries_dict[fn] = Spectra(data=color_data, wavelengths=wavelengths)
                break  # one measurement row per file; the rest is not needed
        else:
            # Fail here, not with a confusing KeyError in a later cell.
            raise ValueError(f"PrintColors/{fn}.csv has no row {SPECTRUM_ROW}")

In [8]:
def remove_trailing_nans(arr):
    """Drop the trailing rows of a 2-D array that contain any NaN.

    Rows up to and including the last NaN-free row are kept, so a NaN row that
    is followed by a valid row is kept too.

    Returns:
        ``arr`` itself if no row contains a NaN; an empty (0-row) slice if every
        row contains a NaN; otherwise ``arr`` truncated after its last NaN-free row.
    """
    bad_rows = np.isnan(arr).any(axis=1)
    if not bad_rows.any():
        return arr
    good_rows = np.flatnonzero(~bad_rows)
    if good_rows.size == 0:
        # Every row has a NaN: nothing valid to keep (previously an IndexError).
        return arr[:0]
    return arr[:good_rows[-1] + 1]

CONE_ROWS = 311   # rows of the cone table that are used (presumably 390-700 nm at 1 nm — TODO confirm)
Q_SHIFT_NM = 15   # how far the hypothetical Q cone is shifted from M, towards longer wavelengths
Q_FILL = 1e-4     # sensitivity used where the shifted curve has no source data

# Column 0 is wavelength; columns 1, 2, 3 are the L, M, S sensitivities.
cone_data = np.genfromtxt('linss2_10e_1.csv', delimiter=',')

l_cone = Cone(cone_data[:CONE_ROWS, [0, 1]])
m_cone = Cone(cone_data[:CONE_ROWS, [0, 2]])
s_cone = Cone(remove_trailing_nans(cone_data[:CONE_ROWS, [0, 3]]))

# Q cone = M cone translated by Q_SHIFT_NM. It is padded with Q_FILL on the
# short-wavelength side and truncated on the long side to keep M's range.
# (The constant is deliberately not called `shift`: a later cell defines a
# `shift()` function, and re-running this cell must not clobber it.)
m_wls = m_cone.wavelengths()
q_reflectance = (
    [(w, Q_FILL) for w in m_wls if w < m_wls[0] + Q_SHIFT_NM]
    + [(w + Q_SHIFT_NM, v) for (w, v) in m_cone.reflectance if w + Q_SHIFT_NM <= m_wls[-1]]
)
q_cone = Cone(reflectance=np.array(q_reflectance))

trichromat = Observer([s_cone, m_cone, l_cone], min_transition_size=5)
tetrachromat = Observer([s_cone, m_cone, q_cone, l_cone], min_transition_size=5)

In [10]:
s_cone.data()

array([9.54729e-03, 1.14794e-02, 1.37986e-02, 1.65746e-02, 1.98869e-02,
       2.38250e-02, 2.84877e-02, 3.39832e-02, 4.04274e-02, 4.79417e-02,
       5.66498e-02, 6.66757e-02, 7.81479e-02, 9.11925e-02, 1.05926e-01,
       1.22451e-01, 1.40844e-01, 1.61140e-01, 1.83325e-01, 2.07327e-01,
       2.33008e-01, 2.60183e-01, 2.88723e-01, 3.18512e-01, 3.49431e-01,
       3.81363e-01, 4.14141e-01, 4.47350e-01, 4.80439e-01, 5.12767e-01,
       5.43618e-01, 5.72399e-01, 5.99284e-01, 6.24786e-01, 6.49576e-01,
       6.74474e-01, 7.00186e-01, 7.26460e-01, 7.52726e-01, 7.78333e-01,
       8.02555e-01, 8.24818e-01, 8.45422e-01, 8.64961e-01, 8.84100e-01,
       9.03573e-01, 9.23844e-01, 9.44055e-01, 9.62920e-01, 9.79057e-01,
       9.91020e-01, 9.97765e-01, 9.99982e-01, 9.98861e-01, 9.95628e-01,
       9.91515e-01, 9.87377e-01, 9.82619e-01, 9.76301e-01, 9.67513e-01,
       9.55393e-01, 9.39499e-01, 9.20807e-01, 9.00592e-01, 8.80040e-01,
       8.60240e-01, 8.42031e-01, 8.25572e-01, 8.10860e-01, 7.979

In [None]:
cmy_neugebauer = Neaugebauer(spectra_primaries_dict)

In [None]:
trichromat_observe = th.from_numpy(trichromat.get_sensor_matrix(wavelengths)).to(th.float32)
trichromat_whitepoint = th.from_numpy(trichromat.get_whitepoint(wavelengths)).to(th.float32)
cmy_neugebauer.mix(th.tensor((0.3, 0.5 , 0.2), dtype=th.float32))

In [None]:
trichromat_whitepoint

In [None]:
def neugebauer_mix(percentages, n=50, primaries=None, wls=None):
    """Yule-Nielsen-modified spectral Neugebauer mix (NumPy reference version).

    Each primary key is a binary string (e.g. ``"101"``) marking which inks
    are present. Its Demichel weight is ``prod(p_i if bit_i else 1 - p_i)``.
    Spectra are averaged in the ``1/n`` power domain and raised back to ``n``.

    Args:
        percentages: ink coverages in [0, 1], one per key bit (array-like; a
            detached torch tensor also works).
        n: Yule-Nielsen n-factor.
        primaries: mapping key -> Spectra. Defaults to the notebook-level
            ``spectra_primaries_dict``.
        wls: output wavelength grid. Defaults to the notebook-level
            ``wavelengths``, looked up at call time. Note that a later cell
            rebinds that global to a 1 nm grid.

    Returns:
        Spectra of the mixed ink.
    """
    if primaries is None:
        primaries = spectra_primaries_dict
    if wls is None:
        wls = wavelengths
    percentages = np.asarray(percentages, dtype=float)
    # Explicit float zeros: `wavelengths` is an integer arange, so
    # zeros_like(wavelengths) would give an int accumulator.
    output = Spectra(wavelengths=wls, data=np.zeros(len(wls), dtype=float))
    for key, primary in primaries.items():
        bits = np.array([int(b) for b in key])
        weight = np.prod(np.where(bits == 1, percentages, 1 - percentages))
        output += weight * (primary ** (1 / n))
    return output ** n

# Sample the CMY cube on an 11x11x11 grid (0%, 10%, ..., 100% coverage).
# For each sample: check the torch Neugebauer model against the NumPy
# reference, then record its whitepoint-normalised trichromat response.
GAMUT_STEPS = np.linspace(0, 1, 11)  # exact endpoints, unlike arange(0, 1.1, 0.1)

_tri_ink_gamut = {}
for c in GAMUT_STEPS:
    for m in GAMUT_STEPS:
        for y in GAMUT_STEPS:
            mixed_spectra = cmy_neugebauer.mix(th.tensor((c, m, y), dtype=th.float32))
            mixed_spectra2 = neugebauer_mix((c, m, y))
            assert np.allclose(mixed_spectra.numpy(), mixed_spectra2.data())
            tristimulus = th.matmul(trichromat_observe, mixed_spectra.T).squeeze() / trichromat_whitepoint.T

            # Round, don't truncate: int() of e.g. 100 * 0.29 == 28.999... gives 28.
            key = (int(round(100 * c)), int(round(100 * m)), int(round(100 * y)))
            _tri_ink_gamut[key] = tristimulus

TRI_INK_GAMUT = np.array(list(_tri_ink_gamut.values()))

In [None]:
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import ConvexHull

In [None]:
def plot_pointcloud_hull(points, labels=('S','M','L'), fig=None, ax=None, alpha=1):
    """Draw the edges of the convex hull of a 3-D point cloud.

    Args:
        points: (M, 3) array-like of points, e.g. S/M/L responses.
        labels: x, y, z axis labels.
        fig, ax: existing figure / 3-D axes to draw into. If ``ax`` is given it
            is used, and ``fig`` defaults to ``ax.figure``. Otherwise a 3-D
            subplot is added to ``fig``, or to a new figure if ``fig`` is None.
        alpha: line opacity.

    Returns:
        (fig, ax)
    """
    points = np.asarray(points)
    hull = ConvexHull(points)

    if ax is None:
        if fig is None:
            fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
    elif fig is None:
        fig = ax.figure
    ax.set_xlabel(labels[0])
    ax.set_ylabel(labels[1])
    ax.set_zlabel(labels[2])

    for simplex in hull.simplices:
        # Close each facet outline by returning to its first vertex.
        loop = np.append(simplex, simplex[0])
        ax.plot(points[loop, 0], points[loop, 1], points[loop, 2], "k-", alpha=alpha)

    return fig, ax

%matplotlib notebook
plot_pointcloud_hull(TRI_INK_GAMUT)

In [None]:
from collections import defaultdict

def furthest_metamers(points, axis, stepsize=0.05):
    """Find the metamer pair in a point cloud that is most separated along one axis.

    Points are bucketed on a grid of cell size ``stepsize`` over every
    coordinate except ``axis``. Points sharing a bucket therefore nearly match
    on those coordinates. The bucket with the largest spread along ``axis`` is
    chosen, and its two extreme points along ``axis`` form the pair.

    Prints the row indices of the max-``axis`` and min-``axis`` points.

    Args:
        points: (N, D) array-like with N >= 1 and D >= 2.
        axis: coordinate along which separation is measured (negative allowed).
        stepsize: bucket width on the other coordinates.

    Returns:
        |difference along axis| / max |difference along the other axes| for
        that pair. Large values mean a big separation along ``axis`` for a
        small mismatch elsewhere. The result is inf/nan if the pair coincides
        on all other axes.

    Raises:
        ValueError: if ``points`` is empty or not (N, D>=2), or ``axis`` is out of range.
    """
    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[0] == 0 or points.shape[1] < 2:
        raise ValueError("points must be a non-empty (N, D) array with D >= 2")
    n_dims = points.shape[1]
    if not -n_dims <= axis < n_dims:
        raise ValueError(f"axis {axis} out of range for {n_dims}-D points")
    axis %= n_dims
    other_axes = [i for i in range(n_dims) if i != axis]

    # floor, not int(): int() truncates toward zero, which would make the
    # bucket straddling 0 twice as wide as the others.
    keys = np.floor(points[:, other_axes] / stepsize).astype(int)
    buckets = defaultdict(list)
    for idx, key in enumerate(map(tuple, keys)):
        buckets[key].append(idx)

    max_dist = -1
    max_key = None
    for key, idxs in buckets.items():
        qs = points[idxs, axis]
        span_qs = qs.max() - qs.min()
        if span_qs > max_dist:
            max_dist = span_qs
            max_key = key

    thickest_indices = buckets[max_key]
    thickest_qs = points[thickest_indices, axis]
    (max_q, idx_max_q) = max(zip(thickest_qs, thickest_indices))
    (min_q, idx_min_q) = min(zip(thickest_qs, thickest_indices))

    print(idx_max_q, idx_min_q)

    diff = np.abs(points[idx_min_q] - points[idx_max_q])
    # Normalise by *all* other axes. The old `arr[0:axis]` skipped axes after
    # `axis` and was empty (ValueError) for axis=0.
    return diff[axis] / diff[other_axes].max()

furthest_metamers(TRI_INK_GAMUT, 2)

TRI_GAMUT_POINTS = trichromat.get_full_colors().T

furthest_metamers(TRI_GAMUT_POINTS, 2)

TRI_GAMUT_POINTS[78859]

TRI_GAMUT_POINTS[394674]

In [None]:
[0.19782491, 0.49949561, 0.69155834]

## Backprop color matching

from models import ColorMatchingModel, NeugabauerMatchingModel
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def fit_cmy_to_tristimulus(target_tristimulus, num_epochs=100, lr=0.05, log_every=10):
    """Fit CMY coverages whose Neugebauer mix matches ``target_tristimulus``.

    Runs plain SGD on ``cmy_neugebauer.metric`` for the trichromat observer and
    prints the loss every ``log_every`` epochs.

    Returns:
        The trained NeugabauerMatchingModel. Calling it returns the weights
        that are fed to ``cmy_neugebauer.metric``.
    """
    model = NeugabauerMatchingModel(3)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    for epoch in range(1, num_epochs + 1):
        weights = model()
        loss = cmy_neugebauer.metric(weights, target_tristimulus, trichromat_observe, trichromat_whitepoint)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {loss.item()}")
    return model


# Targets are hard-coded tristimulus values, presumably copied from the
# TRI_GAMUT_POINTS lookups above (TODO: confirm which row each one is).
target_tristimulus = th.tensor([0.19782491, 0.49949561, 0.69155834])
model = fit_cmy_to_tristimulus(target_tristimulus)
print(cmy_neugebauer.observe(model(),
                            trichromat_observe,
                            trichromat_whitepoint))

target_tristimulus = th.tensor([0.16047033, 0.45036489, 0.28414145])
model2 = fit_cmy_to_tristimulus(target_tristimulus)
print(model2())

## Dual Optimization

In [None]:
from models import grid_train

In [None]:
cmy_tri_params = grid_train(cmy_neugebauer, trichromat)

In [None]:
cmy_neugebauer.observe(cmy_tri_params[:3],
                            trichromat_observe,
                            trichromat_whitepoint)

In [None]:
cmy_neugebauer.observe(cmy_tri_params[3:],
                            trichromat_observe,
                            trichromat_whitepoint)

**Dual optimization seems to work!**

In [None]:
neugebauer_mix(cmy_tri_params[:3].detach()).plot()
neugebauer_mix(cmy_tri_params[3:].detach() ).plot()
l_cone.plot(color='black', alpha=0.5)
m_cone.plot(color='black', alpha=0.5)
s_cone.plot(color='black', alpha=0.5)

## Dual Optimization for Tetrachromacy

In [None]:
cmy_tetra_params = grid_train(cmy_neugebauer, tetrachromat, NUM_INITS=200)

Since there is probably no exact solution, it would be nice to see whether we can find a class of solutions instead.

## Approximating theoretical inks

In [None]:
def shift(pigment : Pigment,shiftAmount : int):
    """Return ``pigment``'s reflectance curve translated along wavelength.

    The curve is moved by ``shiftAmount`` (positive = towards longer
    wavelengths) and resampled on ``pigment.wavelengths()``. Samples that would
    come from outside the original range are filled with the first or last
    reflectance value. Shifts that are not multiples of the grid step are
    linearly interpolated instead of producing an off-grid result.

    Note: this rebinds the notebook name ``shift``, which an earlier cell
    used for an integer.
    """
    samples = np.asarray(pigment.reflectance, dtype=float)
    src_wls, src_vals = samples[:, 0], samples[:, 1]
    grid = np.asarray(pigment.wavelengths(), dtype=float)
    # The output at w is the input at w - shiftAmount. np.interp clamps to the
    # end values outside the source range, which is exactly the padding wanted.
    shifted = np.interp(grid - shiftAmount, src_wls, src_vals)
    return Pigment(reflectance=np.column_stack((grid, shifted)))

cyan = spectra_primaries_dict["100"]
magenta = spectra_primaries_dict["010"]
yellow = spectra_primaries_dict["001"]
green = shift(spectra_primaries_dict["100"],70)
lime = shift(spectra_primaries_dict["100"],90)
orange = shift(spectra_primaries_dict["010"],-50)
purple = shift(spectra_primaries_dict["010"],-20)
maroon = shift(spectra_primaries_dict["001"],90)
flaxen = shift(spectra_primaries_dict["001"],-10)
figc, axc = plt.subplots()
for spectra in [cyan, magenta, yellow, green, lime, purple]:
    spectra.plot(color=np.clip(spectra.to_rgb(),0,1), ax=axc)
    
s_cone.plot(ax=axc, color='black')
m_cone.plot(ax=axc, color='black')
q_cone.plot(ax=axc, color='black')
l_cone.plot(ax=axc, color='black')

In [None]:
def k_s_from_pigments(pigments):
    """Stack per-pigment absorption (K) and scattering (S) spectra column-wise.

    Entries that are not already ``Pigment`` objects are wrapped with
    ``Pigment(reflectance=...)`` first.

    Returns:
        ``(K, S)``. Column j of each matrix belongs to ``pigments[j]``.
    """
    pairs = [
        (p if isinstance(p, Pigment) else Pigment(reflectance=p)).get_k_s()
        for p in pigments
    ]
    k_matrix = np.column_stack([k for k, _ in pairs])
    s_matrix = np.column_stack([s for _, s in pairs])
    return k_matrix, s_matrix

In [None]:
def km_mix(pigments , concentrations = None):
    """Mix pigments by linearly combining their K and S spectra (Kubelka-Munk style).

    ``K_mix = K @ c`` and ``S_mix = (S @ c) / len(pigments)**2``. The extra
    ``1/len**2`` factor on scattering is an empirical correction and is not
    part of standard two-constant Kubelka-Munk mixing.

    Args:
        pigments: non-empty sequence of Pigment objects (or data accepted by
            ``Pigment(reflectance=...)``). The wavelength grid is taken from
            ``pigments[0].wavelengths()``.
        concentrations: per-pigment weights (list, array or similar). If None
            or empty, equal weights ``1/len(pigments)`` are used.

    Returns:
        Pigment built from the mixed K and S spectra.
    """
    K_matrix, S_matrix = k_s_from_pigments(pigments)
    wavelengths = pigments[0].wavelengths()
    n_pigments = len(pigments)

    # Test `is None` / len, not truthiness: `not np.array([a, b])` raises
    # "truth value of an array is ambiguous".
    if concentrations is None or len(concentrations) == 0:
        concentrations = np.full(n_pigments, 1 / n_pigments)
    else:
        concentrations = np.asarray(concentrations, dtype=float)

    K_mix = K_matrix @ concentrations
    S_mix = S_matrix @ concentrations / (n_pigments ** 2)  # varuns secret correction term

    k = np.column_stack((wavelengths, K_mix))
    s = np.column_stack((wavelengths, S_mix))

    return Pigment(k=k, s=s)

In [None]:
km_mix(
[
    spectra_primaries_dict["100"],
    spectra_primaries_dict["001"],


]
).plot(color='red')
spectra_primaries_dict["101"].plot()

plt.show()

In [None]:
km_mix(
[
    spectra_primaries_dict["100"],
    spectra_primaries_dict["010"],


]
).plot(color='red')
spectra_primaries_dict["110"].plot()

plt.show()

In [None]:
km_mix(
[
    spectra_primaries_dict["001"],
    spectra_primaries_dict["010"],

    
]
).plot(color='red')
spectra_primaries_dict["011"].plot()

plt.show()

In [None]:
km_mix(
[
    spectra_primaries_dict["001"],
    spectra_primaries_dict["010"],
    spectra_primaries_dict["100"],


]
).plot(color='red')
spectra_primaries_dict["111"].plot()

plt.show()

In [None]:
expanded_primaries_dict = {}
# Single-ink primaries, one-hot keyed in the order
# Cyan, Magenta, Yellow, Green, Orange, Purple (the "p" in cmygop).
cmygop_inks_dict = dict(zip(
    ("100000", "010000", "001000", "000100", "000010", "000001"),
    (spectra_primaries_dict["100"], spectra_primaries_dict["010"], spectra_primaries_dict["001"],
     green, orange, purple),
))

In [None]:
expanded_primaries_dict['000000'] = spectra_primaries_dict['000']
expanded_primaries_dict['101000'] = spectra_primaries_dict['101']
expanded_primaries_dict['110000'] = spectra_primaries_dict['110']
expanded_primaries_dict['011000'] = spectra_primaries_dict['011']
expanded_primaries_dict['111000'] = spectra_primaries_dict['111']

In [None]:
# Fill in every remaining 6-ink combination (masks 1..63) by km_mix-ing the
# single inks whose bits are set. Combinations already in the dict are kept.
for i in range(1, 2 ** 6):
    binary_str = f"{i:06b}"
    inks_to_mix = [
        cmygop_inks_dict[format(1 << (5 - j), '06b')]
        for j, bit in enumerate(binary_str)
        if bit == '1'
    ]
    if binary_str in expanded_primaries_dict:
        continue
    mixed_ink = km_mix(inks_to_mix)
    expanded_primaries_dict[binary_str] = mixed_ink

In [None]:
expanded_primaries_dict['101001'].plot()

## Running tetrachromacy on expanded model

In [None]:
cmygop_neugebauer = Neaugebauer(expanded_primaries_dict)

In [None]:
cmygop_tetra_params = grid_train(cmygop_neugebauer, tetrachromat, NUM_INITS=5000, NUM_EPOCHS=100)

In [None]:
# (A stray `gibber` expression stood here. It is not defined anywhere, so it
# raised NameError and stopped Restart & Run All.)

In [None]:
tetrachromat_observe = th.from_numpy(tetrachromat.get_sensor_matrix(wavelengths)).to(th.float32)
tetrachromat_whitepoint = th.from_numpy(tetrachromat.get_whitepoint(wavelengths)).to(th.float32)

In [None]:
spectra1 = 2 * Spectra(wavelengths=wavelengths,data=cmygop_neugebauer.mix(cmygop_tetra_params[:6]).squeeze())

In [None]:
spectra2 = 2 * Spectra(wavelengths=wavelengths,data=cmygop_neugebauer.mix(cmygop_tetra_params[6:]).squeeze())

In [None]:
spectra1.plot()
spectra2.plot()
l_cone.plot(alpha=0.3, color='black')
q_cone.plot(alpha=0.3, color='black')
m_cone.plot(alpha=0.3, color='black')
s_cone.plot(alpha=0.3, color='black')

plt.show()

In [None]:
tetrachromat.observe(spectra1)

In [None]:
tetrachromat.observe(spectra2)

In [None]:
cmygop_neugebauer.dual_optimization_metric(cmygop_tetra_params, tetrachromat_observe, tetrachromat_whitepoint, 2)

In [None]:
# NOTE(review): `best_cmygop` is not defined anywhere in this notebook, so this
# cell and the next two raise NameError on a fresh kernel. It presumably held
# a saved good `cmygop_tetra_params` result from a deleted cell. TODO: define
# it explicitly.
best_cmygop[6:]

In [None]:
a = 4 * Spectra(data=cmygop_neugebauer.mix(best_cmygop[:6]).detach().numpy().flatten(), wavelengths=wavelengths)

In [None]:
b = 4 * Spectra(data=cmygop_neugebauer.mix(best_cmygop[6:]).detach().numpy().flatten(), wavelengths=wavelengths)

In [None]:
fig7, ax7 = plt.subplots()
a.plot(ax=ax7)
b.plot(ax=ax7)
l_cone.plot(ax=ax7,color='black', alpha=0.5)
q_cone.plot(ax=ax7,color='black', alpha=0.5)
m_cone.plot(ax=ax7,color='black', alpha=0.5)
s_cone.plot(ax=ax7,color='black', alpha=0.5)
plt.show()

In [None]:
a.plot()
b.plot()
plt.show()

Surprising!

For the best result found, this match is quite poor.

In [None]:
# Single-ink primaries, one-hot keyed in the order
# Cyan, Magenta, Yellow, Lime, Purple (the "p" in cmylp).
cmylp_primaries_dict = {}
cmylp_inks_dict = dict(zip(
    ("10000", "01000", "00100", "00010", "00001"),
    (spectra_primaries_dict["100"], spectra_primaries_dict["010"], spectra_primaries_dict["001"],
     lime, purple),
))

In [None]:
cmylp_primaries_dict['00000'] = spectra_primaries_dict['000']
cmylp_primaries_dict['10100'] = spectra_primaries_dict['101']
cmylp_primaries_dict['11000'] = spectra_primaries_dict['110']
cmylp_primaries_dict['01100'] = spectra_primaries_dict['011']
cmylp_primaries_dict['11100'] = spectra_primaries_dict['111']

In [None]:
# Fill in every remaining 5-ink combination (masks 1..31) by km_mix-ing the
# single inks whose bits are set. Combinations already in the dict are kept.
for i in range(1, 2 ** 5):
    binary_str = f"{i:05b}"
    inks_to_mix = [
        cmylp_inks_dict[format(1 << (4 - j), '05b')]
        for j, bit in enumerate(binary_str)
        if bit == '1'
    ]
    if binary_str in cmylp_primaries_dict:
        continue
    mixed_ink = km_mix(inks_to_mix)
    cmylp_primaries_dict[binary_str] = mixed_ink

In [None]:
cmylp_neugebauer = Neaugebauer(cmylp_primaries_dict)

In [None]:
cmylp_tetra_params = grid_train(cmylp_neugebauer, tetrachromat, NUM_INITS=1000)

In [None]:
# Single-ink primaries, one-hot keyed in the order
# Cyan, Magenta, Yellow, Green, Lime, Purple (the "p" in cmyglp).
cmyglp_primaries_dict = {}
cmyglp_inks_dict = dict(zip(
    ("100000", "010000", "001000", "000100", "000010", "000001"),
    (spectra_primaries_dict["100"], spectra_primaries_dict["010"], spectra_primaries_dict["001"],
     green, lime, purple),
))

In [None]:
cmyglp_primaries_dict['000000'] = spectra_primaries_dict['000']
cmyglp_primaries_dict['101000'] = spectra_primaries_dict['101']
cmyglp_primaries_dict['110000'] = spectra_primaries_dict['110']
cmyglp_primaries_dict['011000'] = spectra_primaries_dict['011']
cmyglp_primaries_dict['111000'] = spectra_primaries_dict['111']

In [None]:
# Fill in every remaining 6-ink combination (masks 1..63) by km_mix-ing the
# single inks whose bits are set. Combinations already in the dict are kept.
for i in range(1, 2 ** 6):
    binary_str = f"{i:06b}"
    inks_to_mix = [
        cmyglp_inks_dict[format(1 << (5 - j), '06b')]
        for j, bit in enumerate(binary_str)
        if bit == '1'
    ]
    if binary_str in cmyglp_primaries_dict:
        continue
    mixed_ink = km_mix(inks_to_mix)
    cmyglp_primaries_dict[binary_str] = mixed_ink

In [None]:
cmyglp_neugebauer = Neaugebauer(cmyglp_primaries_dict)

In [None]:
cmyglp_tetra_params = grid_train(cmyglp_neugebauer, tetrachromat, NUM_INITS=1000)

In [None]:
tetrachromat.observe(Spectra(wavelengths=wavelengths,
                             data=cmyglp_neugebauer.mix(cmyglp_tetra_params[:6]).squeeze() ) )

## Finding the width of the tetrachromat gamut

In [None]:
wavelengths5 = np.array(range(400, 701, 5))

In [None]:
len(wavelengths5)

In [None]:
from models import SigmoidModel, full_spectra_dual_optimization_metric
import torch.optim as optim

In [None]:
# Optimise free-form spectra on the 5 nm grid with
# full_spectra_dual_optimization_metric. The model holds
# 2 * len(wavelengths5) values (presumably one spectrum per half).
NUM_EPOCHS_5NM = 100_000  # SGD steps
LOG_EVERY_5NM = 1_000     # record/print the loss this often

instance = SigmoidModel(th.rand(2 * len(wavelengths5)))

optimizer = optim.SGD(instance.parameters(), lr=0.01)
observer_matrix = th.from_numpy(tetrachromat.get_sensor_matrix(wavelengths5)).to(th.float32)
whitepoint = th.from_numpy(tetrachromat.get_whitepoint(wavelengths5)).to(th.float32)

epochs_values = []
loss_values = []


for epoch in range(1, NUM_EPOCHS_5NM + 1):
    spectras = instance()
    loss = full_spectra_dual_optimization_metric(
        spectras, observer_matrix, whitepoint, 2
    )
    # Stepping on a NaN loss would write NaNs into the parameters; stop
    # instead of spending the remaining epochs.
    if th.isnan(loss):
        print(f"Loss is NaN after {epoch} epochs; stopping early.")
        break

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY_5NM == 0:
        ls = loss.item()
        print(f"After {epoch} epochs, loss is {ls}.")
        loss_values.append(ls)
        epochs_values.append(epoch)


plt.plot(epochs_values, loss_values, label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Curve")
plt.legend()
plt.grid(True)
plt.show()

instance()

In [None]:
# The model output is split into two spectra of len(wavelengths5) samples each.
# Run the model once and slice, rather than running it per half.
met_data = instance().detach().numpy()
n_wl5 = len(wavelengths5)
met1 = Spectra(wavelengths=wavelengths5, data=met_data[:n_wl5])
met2 = Spectra(wavelengths=wavelengths5, data=met_data[n_wl5:])

In [None]:
met1.plot()
met2.plot()

In [None]:
print(tetrachromat.observe(met1))
print(tetrachromat.observe(met2))

In [None]:
tetrachromat.observe(met2) - tetrachromat.observe(met1)

In [None]:
wavelengths1 = np.array(
    range(400, 701, 1))

In [None]:
len(wavelengths1)

In [None]:
# Same optimisation as the 5 nm run, on the 1 nm grid.
# WARNING: 100 million SGD steps effectively never finish under
# Restart & Run All. A converged result is presumably what is stored in
# `curr_result` / `prev_result` below; lower NUM_EPOCHS_1NM to iterate.
NUM_EPOCHS_1NM = 100_000_000
LOG_EVERY_1NM = 100_000

instance = SigmoidModel(th.rand(2 * len(wavelengths1)))

optimizer = optim.SGD(instance.parameters(), lr=0.01)
observer_matrix = th.from_numpy(tetrachromat.get_sensor_matrix(wavelengths1)).to(th.float32)
whitepoint = th.from_numpy(tetrachromat.get_whitepoint(wavelengths1)).to(th.float32)

epochs_values = []
loss_values = []


for epoch in range(1, NUM_EPOCHS_1NM + 1):
    spectras = instance()
    loss = full_spectra_dual_optimization_metric(
        spectras, observer_matrix, whitepoint, 2
    )
    # Stepping on a NaN loss would write NaNs into the parameters; stop
    # instead of spending the remaining epochs.
    if th.isnan(loss):
        print(f"Loss is NaN after {epoch} epochs; stopping early.")
        break

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY_1NM == 0:
        ls = loss.item()
        print(f"After {epoch} epochs, loss is {ls}.")
        loss_values.append(ls)
        epochs_values.append(epoch)


plt.plot(epochs_values, loss_values, label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Curve")
plt.legend()
plt.grid(True)
plt.show()

instance()

In [None]:
curr_result = th.tensor([0.7265, 0.5280, 0.5190, 0.5710, 0.6233, 0.7038, 0.6039, 0.5630, 0.6357,
        0.7293, 0.5183, 0.6764, 0.6076, 0.6172, 0.6995, 0.6078, 0.7032, 0.7100,
        0.5337, 0.5866, 0.5318, 0.7139, 0.5917, 0.6213, 0.7848, 0.8317, 0.8133,
        0.8181, 0.8507, 0.6854, 0.8392, 0.8669, 0.8649, 0.7565, 0.7646, 0.8730,
        0.8207, 0.7712, 0.8640, 0.8791, 0.8867, 0.8192, 0.8505, 0.8712, 0.8727,
        0.8796, 0.8765, 0.8930, 0.8721, 0.8320, 0.8485, 0.8729, 0.8530, 0.8784,
        0.8810, 0.8769, 0.8506, 0.8473, 0.8391, 0.8643, 0.8513, 0.8580, 0.8629,
        0.7985, 0.8226, 0.7310, 0.8256, 0.8425, 0.7516, 0.4244, 0.7042, 0.6450,
        0.3997, 0.6053, 0.5949, 0.6117, 0.5200, 0.4744, 0.6996, 0.4927, 0.6960,
        0.4572, 0.6599, 0.7082, 0.6182, 0.5596, 0.4117, 0.4281, 0.6713, 0.4941,
        0.4794, 0.3716, 0.3916, 0.5299, 0.3870, 0.4643, 0.3832, 0.5102, 0.3153,
        0.4406, 0.2465, 0.2598, 0.1917, 0.1920, 0.1532, 0.1917, 0.1774, 0.1726,
        0.1290, 0.1437, 0.1302, 0.1056, 0.0984, 0.0978, 0.1089, 0.0834, 0.0872,
        0.1071, 0.0988, 0.1001, 0.0932, 0.1155, 0.1213, 0.1255, 0.1126, 0.1331,
        0.1499, 0.1171, 0.1327, 0.2051, 0.2123, 0.1679, 0.2000, 0.3943, 0.3229,
        0.4864, 0.6695, 0.6501, 0.7343, 0.4834, 0.7432, 0.7698, 0.6546, 0.8178,
        0.7960, 0.8042, 0.8086, 0.8663, 0.8542, 0.8359, 0.8461, 0.8485, 0.8851,
        0.8720, 0.8927, 0.8863, 0.8753, 0.8858, 0.9081, 0.8845, 0.8814, 0.8831,
        0.8842, 0.8831, 0.8808, 0.9029, 0.8688, 0.8809, 0.8916, 0.8823, 0.9002,
        0.8651, 0.8991, 0.9005, 0.8898, 0.8980, 0.8872, 0.8754, 0.8893, 0.8672,
        0.8663, 0.8949, 0.8852, 0.8488, 0.8817, 0.8179, 0.8440, 0.8710, 0.8504,
        0.8457, 0.8006, 0.8380, 0.8173, 0.7885, 0.6473, 0.6440, 0.6244, 0.6939,
        0.6786, 0.5659, 0.5312, 0.6802, 0.4591, 0.4239, 0.3483, 0.3197, 0.4953,
        0.3673, 0.2841, 0.3350, 0.3934, 0.2838, 0.3406, 0.2464, 0.2274, 0.3158,
        0.2440, 0.2494, 0.2157, 0.2379, 0.2407, 0.1689, 0.2265, 0.2438, 0.1653,
        0.1756, 0.2089, 0.1990, 0.2730, 0.2269, 0.2897, 0.1781, 0.1816, 0.2002,
        0.2734, 0.2423, 0.2098, 0.2797, 0.2891, 0.2081, 0.3425, 0.2336, 0.2421,
        0.2406, 0.2921, 0.2554, 0.3593, 0.3911, 0.2550, 0.4341, 0.4432, 0.3839,
        0.2984, 0.3655, 0.3367, 0.5016, 0.4216, 0.5193, 0.4996, 0.5334, 0.3339,
        0.5019, 0.5415, 0.5271, 0.6127, 0.4647, 0.5145, 0.6335, 0.4188, 0.3990,
        0.6157, 0.4257, 0.4609, 0.4273, 0.5636, 0.4705, 0.5707, 0.4742, 0.6450,
        0.5856, 0.4879, 0.6470, 0.5899, 0.4938, 0.5971, 0.4619, 0.5866, 0.6764,
        0.5991, 0.6149, 0.7069, 0.5595, 0.6014, 0.5539, 0.5302, 0.6613, 0.5855,
        0.7189, 0.6215, 0.6251, 0.5018, 0.5728, 0.5631, 0.6600, 0.5612, 0.7132,
        0.7312, 0.6626, 0.6421, 0.7134, 0.6372, 0.7396, 0.6563, 0.5342, 0.6472,
        0.6704, 0.7254, 0.7441, 0.7411, 0.7693, 0.5968, 0.7302, 0.4958, 0.7315,
        0.6876, 0.5399, 0.7397, 0.8176, 0.7705, 0.8178, 0.8353, 0.4961, 0.4443,
        0.4011, 0.8370, 0.6546, 0.7396, 0.8383, 0.8185, 0.3017, 0.8363, 0.8686,
        0.8498, 0.8635, 0.8396, 0.8201, 0.6444, 0.8373, 0.6560, 0.8650, 0.8570,
        0.7507, 0.6702, 0.8261, 0.7329, 0.6084, 0.8315, 0.8210, 0.7686, 0.8162,
        0.8208, 0.7332, 0.6861, 0.8082, 0.4737, 0.8310, 0.7617, 0.7719, 0.6859,
        0.8231, 0.7934, 0.8129, 0.8192, 0.8251, 0.6930, 0.7636, 0.7775, 0.8349,
        0.7018, 0.6736, 0.6928, 0.7280, 0.8074, 0.7650, 0.6973, 0.7907, 0.7664,
        0.6122, 0.6817, 0.6940, 0.7357, 0.7644, 0.6724, 0.7686, 0.7546, 0.7855,
        0.8355, 0.8029, 0.7963, 0.7891, 0.8449, 0.8745, 0.8169, 0.8895, 0.8965,
        0.8806, 0.8806, 0.8896, 0.8904, 0.9220, 0.9236, 0.9259, 0.9261, 0.9295,
        0.9278, 0.9279, 0.9366, 0.9373, 0.9257, 0.9317, 0.9211, 0.9352, 0.9234,
        0.9229, 0.9176, 0.9262, 0.9152, 0.9219, 0.8960, 0.9017, 0.9105, 0.8835,
        0.8867, 0.8815, 0.7901, 0.8366, 0.7342, 0.8199, 0.7313, 0.7271, 0.7332,
        0.5365, 0.7802, 0.6780, 0.6314, 0.6048, 0.6756, 0.4444, 0.3231, 0.2242,
        0.2574, 0.4211, 0.3484, 0.2758, 0.1883, 0.1667, 0.1684, 0.1893, 0.1553,
        0.1342, 0.1325, 0.1947, 0.1460, 0.2078, 0.1720, 0.2781, 0.1561, 0.1521,
        0.1886, 0.2106, 0.1515, 0.1810, 0.1467, 0.1567, 0.1470, 0.1504, 0.1973,
        0.2489, 0.1476, 0.1990, 0.1494, 0.1682, 0.2943, 0.2307, 0.2774, 0.3212,
        0.2032, 0.2818, 0.4631, 0.4149, 0.5169, 0.2632, 0.3475, 0.3724, 0.6640,
        0.6112, 0.6850, 0.6612, 0.7272, 0.5632, 0.6614, 0.7050, 0.7442, 0.6821,
        0.7345, 0.6863, 0.7252, 0.7074, 0.7774, 0.8096, 0.7619, 0.8186, 0.8498,
        0.7911, 0.8644, 0.8664, 0.8763, 0.8517, 0.8644, 0.8838, 0.8260, 0.8618,
        0.8350, 0.8449, 0.8744, 0.8706, 0.8815, 0.8539, 0.8642, 0.8554, 0.8876,
        0.8461, 0.8431, 0.8802, 0.8582, 0.8556, 0.8776, 0.8752, 0.8769, 0.8493,
        0.8104, 0.8150, 0.8552, 0.8692, 0.7995, 0.8389, 0.8199, 0.7802, 0.8257,
        0.7912, 0.8084, 0.7898, 0.8173, 0.7951, 0.8218, 0.7962, 0.7091, 0.7330,
        0.7641, 0.7909, 0.7292, 0.7359, 0.7818, 0.6702, 0.7915, 0.7447, 0.7960,
        0.7757, 0.7038, 0.6161, 0.7225, 0.7421, 0.6685, 0.6090, 0.7549, 0.7711,
        0.5999, 0.7452, 0.7245, 0.7226, 0.5585, 0.5709, 0.5773, 0.6863, 0.6040,
        0.6319, 0.5801, 0.6103, 0.5330, 0.5982, 0.7244, 0.6303, 0.7311, 0.5641,
        0.7218, 0.7311, 0.7061, 0.6656, 0.6718, 0.6700, 0.5839, 0.7529])

In [None]:
prev_result = th.tensor([0.6038, 0.6022, 0.6533, 0.5899, 0.5862, 0.5246, 0.6919, 0.7083, 0.6855,
        0.5831, 0.7160, 0.7125, 0.5086, 0.7146, 0.5711, 0.7261, 0.6994, 0.5769,
        0.7188, 0.6108, 0.6630, 0.7658, 0.7691, 0.7029, 0.7840, 0.5381, 0.6531,
        0.4957, 0.5039, 0.6698, 0.7042, 0.7736, 0.6810, 0.7888, 0.7157, 0.7675,
        0.7346, 0.7387, 0.8387, 0.8219, 0.7806, 0.8110, 0.7640, 0.7747, 0.5574,
        0.3897, 0.8071, 0.3525, 0.8218, 0.6098, 0.7711, 0.7488, 0.3996, 0.8158,
        0.7387, 0.7951, 0.7923, 0.7800, 0.3882, 0.7349, 0.4754, 0.5051, 0.7427,
        0.7569, 0.7479, 0.7472, 0.7557, 0.8072, 0.7723, 0.7666, 0.7124, 0.8153,
        0.7376, 0.7952, 0.7789, 0.7329, 0.6864, 0.8006, 0.7526, 0.6557, 0.6733,
        0.7175, 0.7298, 0.7309, 0.6075, 0.6383, 0.6851, 0.7332, 0.7856, 0.7591,
        0.6971, 0.6597, 0.7490, 0.7641, 0.6538, 0.7678, 0.7675, 0.7534, 0.8026,
        0.7546, 0.8574, 0.7855, 0.8410, 0.8723, 0.8492, 0.8892, 0.8797, 0.9030,
        0.8757, 0.8733, 0.8815, 0.9148, 0.8907, 0.9062, 0.9198, 0.9015, 0.9034,
        0.8998, 0.9150, 0.9094, 0.9036, 0.9206, 0.9217, 0.9163, 0.9090, 0.9115,
        0.9080, 0.8677, 0.8948, 0.8390, 0.8721, 0.8394, 0.7821, 0.7512, 0.8244,
        0.6840, 0.7629, 0.6793, 0.6120, 0.7819, 0.6359, 0.4651, 0.6789, 0.6329,
        0.4106, 0.5069, 0.5810, 0.5453, 0.2712, 0.4755, 0.5026, 0.2056, 0.1994,
        0.3580, 0.3943, 0.1953, 0.2001, 0.1954, 0.1993, 0.1859, 0.2061, 0.1948,
        0.2600, 0.2611, 0.2771, 0.1808, 0.1809, 0.1874, 0.1830, 0.3754, 0.2860,
        0.2020, 0.2363, 0.1868, 0.3914, 0.3246, 0.2546, 0.1984, 0.2782, 0.2829,
        0.2892, 0.2208, 0.2398, 0.2036, 0.2278, 0.2828, 0.2960, 0.3128, 0.4473,
        0.4172, 0.4128, 0.3684, 0.6090, 0.6412, 0.4830, 0.6195, 0.7190, 0.5789,
        0.6003, 0.6492, 0.7318, 0.7896, 0.6940, 0.6497, 0.7301, 0.7389, 0.7624,
        0.8203, 0.7260, 0.8323, 0.7390, 0.8280, 0.8536, 0.8091, 0.7835, 0.8659,
        0.8059, 0.8223, 0.8118, 0.8241, 0.8713, 0.8154, 0.8140, 0.8804, 0.8388,
        0.8775, 0.8507, 0.8724, 0.8441, 0.8627, 0.8795, 0.8689, 0.8465, 0.8750,
        0.8576, 0.8084, 0.8413, 0.7996, 0.8668, 0.8416, 0.8356, 0.8592, 0.7779,
        0.8173, 0.8011, 0.8190, 0.7699, 0.7498, 0.7572, 0.7828, 0.8036, 0.8047,
        0.7763, 0.7070, 0.7523, 0.8161, 0.6967, 0.6977, 0.7734, 0.7137, 0.7823,
        0.7359, 0.7211, 0.6845, 0.6649, 0.6226, 0.6707, 0.7272, 0.6937, 0.6008,
        0.6652, 0.7354, 0.7703, 0.7602, 0.6475, 0.6972, 0.7220, 0.7421, 0.6685,
        0.5831, 0.7468, 0.7078, 0.6843, 0.7388, 0.7226, 0.7311, 0.7050, 0.5639,
        0.5454, 0.5564, 0.7169, 0.6832, 0.7301, 0.5666, 0.5484, 0.5403, 0.5677,
        0.6886, 0.5410, 0.6891, 0.5868, 0.6185, 0.5629, 0.5961, 0.5379, 0.7162,
        0.6845, 0.5696, 0.6343, 0.7316, 0.7163, 0.4922, 0.5668, 0.5747, 0.5991,
        0.5757, 0.6893, 0.5816, 0.6690, 0.7569, 0.7392, 0.5306, 0.5860, 0.7674,
        0.5699, 0.6656, 0.7817, 0.7988, 0.7290, 0.7338, 0.8133, 0.7710, 0.5711,
        0.8104, 0.8091, 0.7833, 0.7311, 0.7298, 0.7694, 0.8071, 0.6614, 0.7929,
        0.8326, 0.8063, 0.7911, 0.8351, 0.6956, 0.8628, 0.6626, 0.8525, 0.7082,
        0.8053, 0.8403, 0.8557, 0.8585, 0.8402, 0.8509, 0.8438, 0.8448, 0.6585,
        0.8391, 0.7686, 0.6285, 0.8214, 0.7245, 0.7634, 0.8127, 0.7942, 0.6204,
        0.5062, 0.7820, 0.7645, 0.7598, 0.7337, 0.4906, 0.6472, 0.7270, 0.4639,
        0.4457, 0.6779, 0.5782, 0.4945, 0.5674, 0.5448, 0.4162, 0.5718, 0.4532,
        0.4873, 0.4213, 0.6293, 0.4545, 0.4239, 0.4283, 0.6730, 0.4491, 0.5919,
        0.4395, 0.4037, 0.5692, 0.3573, 0.4534, 0.4628, 0.3122, 0.3137, 0.2710,
        0.1790, 0.2602, 0.2503, 0.2142, 0.1613, 0.1830, 0.1258, 0.1316, 0.1467,
        0.1506, 0.1493, 0.1049, 0.1113, 0.1103, 0.1022, 0.1481, 0.1040, 0.1057,
        0.1495, 0.1253, 0.1984, 0.1958, 0.1524, 0.1526, 0.1953, 0.2703, 0.1979,
        0.2469, 0.2947, 0.4925, 0.3734, 0.3778, 0.6430, 0.7345, 0.7001, 0.6227,
        0.7715, 0.7097, 0.7587, 0.6516, 0.7799, 0.7252, 0.8130, 0.8391, 0.7717,
        0.8391, 0.8563, 0.8041, 0.8722, 0.8208, 0.8565, 0.8514, 0.8368, 0.8859,
        0.8838, 0.8805, 0.8644, 0.8719, 0.8408, 0.8741, 0.8528, 0.8594, 0.8334,
        0.8436, 0.8328, 0.8693, 0.8790, 0.8712, 0.8251, 0.8845, 0.8799, 0.8767,
        0.8468, 0.8326, 0.8695, 0.8822, 0.8826, 0.8723, 0.8509, 0.8740, 0.8161,
        0.8462, 0.8438, 0.7879, 0.8402, 0.7297, 0.7879, 0.8048, 0.7780, 0.6034,
        0.7709, 0.6330, 0.5583, 0.7029, 0.5541, 0.4954, 0.6983, 0.4006, 0.6608,
        0.5747, 0.5871, 0.3921, 0.3986, 0.5611, 0.3900, 0.3492, 0.3149, 0.3165,
        0.2739, 0.2744, 0.4302, 0.2410, 0.3038, 0.2725, 0.2665, 0.2838, 0.3668,
        0.3240, 0.2466, 0.2235, 0.2458, 0.2133, 0.2646, 0.2101, 0.3296, 0.2197,
        0.2334, 0.2641, 0.2815, 0.2089, 0.3686, 0.2280, 0.2742, 0.2675, 0.2401,
        0.3756, 0.4149, 0.2446, 0.2773, 0.4150, 0.3507, 0.4076, 0.3804, 0.3339,
        0.4671, 0.3311, 0.3102, 0.3714, 0.5413, 0.3315, 0.3512, 0.3378, 0.4783,
        0.3935, 0.3459, 0.5959, 0.5071, 0.4262, 0.6024, 0.5607, 0.6103, 0.5381,
        0.4648, 0.5184, 0.4787, 0.4109, 0.5861, 0.5560, 0.5510, 0.6221, 0.4849,
        0.6549, 0.6750, 0.5154, 0.5627, 0.5766, 0.6771, 0.6165, 0.5519, 0.6812,
        0.6687, 0.6752, 0.5458, 0.6400, 0.6924, 0.5203, 0.5207, 0.6263, 0.5622,
        0.7196, 0.6466, 0.6182, 0.6645, 0.7056, 0.6604, 0.6535, 0.7201])

In [None]:
# curr_result holds two spectra back to back, len(wavelengths1) samples each.
curr_data = curr_result.numpy()
met1 = Spectra(wavelengths=wavelengths1, data=curr_data[:len(wavelengths1)])
met2 = Spectra(wavelengths=wavelengths1, data=curr_data[len(wavelengths1):])

In [None]:
met1.plot(color='red')
met2.plot()
plt.show()

In [None]:
coefficients = np.polynomial.polynomial.polyfit(met1.wavelengths(), met1.data(), 8)
polynomial = np.polynomial.Polynomial(coefficients)
met1.plot()
plt.plot(wavelengths1, polynomial(wavelengths1))

In [None]:
from scipy.interpolate import CubicSpline
spline = CubicSpline(met1.wavelengths(), met1.data())
met1.plot()
plt.plot(wavelengths1, spline(wavelengths1), '-', label='Cubic Spline')
plt.show()

In [None]:
print(tetrachromat.observe(met1))
print(tetrachromat.observe(met2))

In [None]:
62-53.9

In [None]:
# prev_result (presumably an earlier run) holds two spectra back to back,
# len(wavelengths1) samples each.
prev_data = prev_result.numpy()
met3 = Spectra(wavelengths=wavelengths1, data=prev_data[:len(wavelengths1)])
met4 = Spectra(wavelengths=wavelengths1, data=prev_data[len(wavelengths1):])
print(tetrachromat.observe(met3))
print(tetrachromat.observe(met4))

In [None]:
62-54.68

In [None]:
met4.plot(color='red')
met3.plot()
plt.show()

# Polynomial 
The results look so close to a polynomial that it is worth trying a polynomial model.

In [None]:
from models import PolynomialRootsModel, full_spectra_dual_optimization_metric
import torch.optim as optim

A degree of 8 seems reasonable based on the earlier polynomial fit.

In [None]:
wavelengths1 = np.array(
    range(400, 701, 1))

In [None]:
instance.roots1

In [None]:
degree = 8
instance = PolynomialRootsModel((wavelengths1 - 400 - 150) / 150, degree)

plt.scatter(wavelengths1, instance()[:301].detach().numpy())

In [None]:
instance.coeff1

In [None]:
degree = 8

# Wavelengths 400..700 nm are rescaled to [-1, 1] so the polynomial's powers stay bounded.
instance = PolynomialRootsModel((wavelengths1 - 400 - 150) / 150, degree)
optimizer = optim.SGD(instance.parameters(), lr=0.01)

observer_matrix = th.from_numpy(tetrachromat.get_sensor_matrix(wavelengths1)).to(th.float32)
whitepoint = th.from_numpy(tetrachromat.get_whitepoint(wavelengths1)).to(th.float32)

epochs_values = []
loss_values = []


for epoch in range(1, 10000 + 1):
    spectras = instance()
    loss = full_spectra_dual_optimization_metric(
        spectras, observer_matrix, whitepoint, 2
    )
    if th.isnan(loss):
        print("loss is nan")
        break

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 1000 == 0:
        ls = loss.item()
        print(f"After {epoch} epochs, loss is {ls}.")
        loss_values.append(ls)
        epochs_values.append(epoch)


plt.plot(epochs_values, loss_values, label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Curve")
plt.legend()
plt.grid(True)
plt.show()

instance()

In [None]:
# The model output is split into two spectra of len(wavelengths1) samples each.
# Run the model once and slice, rather than running it per half.
z_data = instance().detach().numpy()
z1 = Spectra(wavelengths=wavelengths1, data=z_data[:len(wavelengths1)])
z2 = Spectra(wavelengths=wavelengths1, data=z_data[len(wavelengths1):])

In [None]:
z1.plot()
z2.plot()

We should probably use a model of the form $b + a(x - k_1)(x - k_2)\cdots(x - k_k)$ instead.

In [None]:
from models import PowerSeriesModel

In [None]:
(wavelengths1 - 400) / 300

In [None]:
instance = PowerSeriesModel((wavelengths1 - 400) / 300, 8)

In [None]:
instance()

In [None]:
# Two hand-specified 0/1 reflectance curves, each a sum of box functions.
# NOTE(review): index i presumably corresponds to wavelength i + 1 nm — TODO
# confirm. Only the last 311 entries are used when these are turned into Spectra.
jessica_transition1 = np.zeros(700)
jessica_transition2 = np.zeros(700)

jessica_transition1[448:548] = 1
jessica_transition1[590:700] = 1

jessica_transition2[0:402] = 1
jessica_transition2[532:610] = 1

In [None]:
jessica_transition2[-301:].shape

In [None]:
# WARNING(review): this rebinds the notebook-level `wavelengths` from the
# 400-700 nm, 10 nm grid (used by the primary spectra and `neugebauer_mix`)
# to a 390-700 nm, 1 nm grid. Re-running earlier cells after this one
# silently uses the wrong grid. TODO: give this grid its own name.
wavelengths = np.arange(390, 701, 1)
wavelengths.shape

In [None]:
# Dedicated 390-700 nm, 1 nm grid for the transition spectra. It does not rely
# on the global `wavelengths`, which other cells use for a 10 nm grid.
transition_wls = np.arange(390, 701, 1)
t1 = Spectra(wavelengths=transition_wls, data=jessica_transition1[-len(transition_wls):])
t2 = Spectra(wavelengths=transition_wls, data=jessica_transition2[-len(transition_wls):])
s_cone.plot()
t1.plot()
t2.plot()

In [None]:
tetrachromat.observe(t1)

In [None]:
tetrachromat.observe(t2)