In [None]:
import numpy as np
import torch
import mm3dtestdata as mm3d
import fusecam
from fusecam.geometric import space
from fusecam.geometric import embedplane
from fusecam.geometric import interpolate
from fusecam.manipimg import rotate_tensor_cube

from fusecam.aiutil import train_scripts
from fusecam.aiutil import ensembling

import matplotlib.pyplot as plt
import einops

from torch.utils.data import TensorDataset, DataLoader

from dlsia.core.networks import sms3d
from dlsia.core import helpers
from dlsia.viz_tools import draw_sparse_network


import torch.nn as nn
import torch.optim as optim



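First, we build synthetic test data: a low-resolution volume (strong blur, `sigma_low`) and a high-resolution volume (light blur, `sigma_high`) of the same phantom.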

In [None]:
# Phantom parameters; `scale` is also the side length of the coordinate grids used later.
scale = 64
border = 10
radius = 10

# Blur widths (sigma): strong blur yields the low-res data, light blur the high-res data.
sigma_low = 3.0
sigma_high = 1.0

# Fix random seeds so noise and network initialisation are reproducible on
# Restart & Run All.
# NOTE(review): assumes mm3d draws from NumPy's global RNG — confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Synthetic "balls and eggs" phantom built from the configuration constants.
obj = mm3d.balls_and_eggs(
    scale=scale,
    border=border,
    radius=radius,
    k0=1.0,
)
# fill() returns three items; only the instance and class maps are kept.
_, instance_map_0, class_map_0 = obj.fill()

In [None]:
# Show the central slice of the class map with one discrete colour per class.
# NOTE(review): assumes class_map_0 holds integer labels 0..max — confirm.
# The class count is taken from the data rather than hard-coded, and vmin/vmax are
# offset by 0.5 so each integer label sits in the middle of its own colour bin.
n_classes = int(class_map_0.max()) + 1
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap = plt.get_cmap('Set1', n_classes)
plt.imshow(class_map_0[scale // 2, ...], cmap=cmap, interpolation='none',
           vmin=-0.5, vmax=n_classes - 0.5)
cbar = plt.colorbar(ticks=range(n_classes))
plt.show()

In [None]:
# Blur the class map twice, with sigma_low for the low-res data and sigma_high for the high-res data.
class_map_low, class_map_high = (
    mm3d.blur_it(class_map_0, sigma=blur_sigma)
    for blur_sigma in (sigma_low, sigma_high)
)

In [None]:
# Density value assigned to each class (0..3), each held in a one-element array.
tomo_class_0, tomo_class_1, tomo_class_2, tomo_class_3 = (
    np.array([density]) for density in (0, 0.05, 0.25, 0.35)
)
# Column vector with one row per class: shape (4, 1).
class_actions_tomo = np.column_stack(
    [tomo_class_0, tomo_class_1, tomo_class_2, tomo_class_3]
).T

plt.bar(["Class 0", "Class 1", "Class 2", "Class 3"], class_actions_tomo.ravel())
plt.title("Density")
plt.show()

In [None]:
# Convert each blurred class map into a density volume, then add the output of
# mm3d.noise to each.
# NOTE(review): meaning of noise's (0.01, 0.0) arguments is defined in mm3d — confirm.
low_map, high_map = (
    mm3d.compute_weighted_map(blurred_classes, class_actions_tomo)
    for blurred_classes in (class_map_low, class_map_high)
)

low_map = low_map + mm3d.noise(low_map, 0.01, 0.0)
high_map = high_map + mm3d.noise(high_map, 0.01, 0.0)

In [None]:
# Central slice (last-axis index scale//2) of the noisy low-res and high-res volumes.
# Titles and colour bars make the two otherwise identical-looking figures comparable.
plt.imshow(low_map[0, :, :, scale // 2])
plt.colorbar()
plt.title("low_map[0, :, :, scale//2]")
plt.show()

plt.imshow(high_map[0, :, :, scale // 2])
plt.colorbar()
plt.title("high_map[0, :, :, scale//2]")
plt.show()

Now that we have the data, we create the geometric objects that describe the 3D volume and the 2D image plane.

In [None]:
# Metric descriptions of the 3D voxel volume and the 2D image plane: both start at
# the origin with unit step size, identity orientation and zero translation.
space_object = space.SpatialVolumeMetric(
    origin=(0, 0, 0),
    step_size=(1, 1, 1),
    orientation=torch.eye(3),
    translation=(0, 0, 0),
)
plane_object = space.SpatialPlaneMetric(
    origin=(0, 0),
    step_size=(1, 1),
    orientation=torch.eye(2),
    translation=(0, 0),
)

With the two geometric objects in place, we build the 2D and 3D coordinate grids and embed two planes into the volume.

In [None]:
# 2D pixel coordinates of a scale x scale plane: UV has one (u, v) row per pixel.
u = torch.linspace(0, scale - 1, scale)
U, V = torch.meshgrid(u, u, indexing='ij')
UV = torch.stack([U.flatten(), V.flatten()]).T

# 3D voxel coordinates of the scale^3 volume: XYZ has one (x, y, z) row per voxel.
x = torch.linspace(0, scale - 1, scale)
X, Y, Z = torch.meshgrid(x, x, x, indexing="ij")
XYZ = torch.stack([X.flatten(), Y.flatten(), Z.flatten()]).T

print(UV.shape, XYZ.shape)

In [None]:
# Embed the 2D grid UV into the volume as two planes through (scale//2, scale//2, scale//2):
# plane 1 has normal +z and plane 2 has normal -y. Each call also receives the 2D
# point (scale//2, scale//2), presumably the in-plane anchor mapped onto
# point_on_plane — confirm in embedplane.
aligner_1 = embedplane.Plane3DAligner(normal=[0.0, 0.0, 1.0],
                                      point_on_plane=[scale // 2] * 3)
point_on_plane_2D_1 = (scale // 2, scale // 2)
aligned_points_1 = aligner_1.align_points_to_3d(UV, point_on_plane_2D_1,
                                                rotation_angle=0)

aligner_2 = embedplane.Plane3DAligner(normal=[0.0, -1.0, 0.0],
                                      point_on_plane=[scale // 2] * 3)
point_on_plane_2D_2 = (scale // 2, scale // 2)
aligned_points_2 = aligner_2.align_points_to_3d(UV, point_on_plane_2D_2,
                                                rotation_angle=0)




In [None]:
# For each embedded plane, look up neighbouring voxels of every plane point
# (k = n_nearest, via find_nearest) and convert the returned distances into
# interpolation weights (power=3.0, cutoff=2.0).
# NOTE(review): exact neighbour/cutoff semantics live in fusecam.geometric.interpolate.
n_nearest = 5

indices_1, near_dist_1 = interpolate.find_nearest(XYZ, aligned_points_1, n_nearest)
weights_1 = interpolate.compute_weights(near_dist_1, power=3.0, cutoff=2.0)

indices_2, near_dist_2 = interpolate.find_nearest(XYZ, aligned_points_2, n_nearest)
weights_2 = interpolate.compute_weights(near_dist_2, power=3.0, cutoff=2.0)

In [None]:
# Sample the flattened high-res volume on each embedded plane using that plane's
# precomputed neighbour indices and weights (one fresh input tensor per call).
funct_1, funct_2 = (
    interpolate.inverse_distance_weighting_with_weights(
        torch.Tensor(high_map.flatten()), plane_indices, plane_weights
    )
    for plane_indices, plane_weights in ((indices_1, weights_1), (indices_2, weights_2))
)



In [None]:
# Put the interpolated plane values back on the (scale, scale) pixel grid.
# `reshape` gives the same row-major layout as einops' "(X Y) -> X Y" but is
# idempotent: re-running this cell on an already-2D tensor is a no-op, not an error.
funct_1 = funct_1.reshape(scale, scale)
funct_2 = funct_2.reshape(scale, scale)

# Plane 1 (normal +z) next to the matching high-res slice.
plt.imshow(funct_1.numpy())
plt.colorbar()
plt.title("Plane 1 (normal +z), interpolated")
plt.show()

plt.imshow(high_map[0, :, :, scale // 2])
plt.colorbar()
plt.title("high_map[0, :, :, scale//2]")
plt.show()

# Plane 2 (normal -y) next to the matching high-res slice.
plt.imshow(funct_2.numpy())
plt.colorbar()
plt.title("Plane 2 (normal -y), interpolated")
plt.show()

plt.imshow(high_map[0, :, scale // 2, :])
plt.colorbar()
plt.title("high_map[0, :, scale//2, :]")
plt.show()




In [None]:
# Residual between each interpolated plane and the high-res slice it should reproduce.
for interpolated, reference in ((funct_1, high_map[0, :, :, scale // 2]),
                                (funct_2, high_map[0, :, scale // 2, :])):
    plt.imshow(interpolated.numpy() - reference)
    plt.colorbar()
    plt.show()


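Build the data loader: each sample pairs the low-res volume with a high-res reference slice and that slice's interpolation weights and indices.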

In [None]:
# One training sample per embedded plane. Both samples use the same low-res volume
# as network input. Each target is the flattened high-res slice on that plane:
#   plane 1 is the last axis at scale//2; plane 2 is axis 2 at scale//2.
# Each target is paired with its plane's neighbour weights and indices.
# The slices use scale//2 rather than a hard-coded 32, so they stay aligned with the
# planes (built through scale//2) if `scale` changes.
my_3d_maps = torch.stack([torch.Tensor(low_map), torch.Tensor(low_map)])
my_2d_maps = torch.stack([torch.Tensor(high_map[0, :, :, scale // 2]).flatten(),
                          torch.Tensor(high_map[0, :, scale // 2, :]).flatten()])
my_weights = torch.stack([weights_1, weights_2])
my_indices = torch.stack([indices_1, indices_2])

In [None]:
# Each item is (input volume, flattened target slice, weights, indices).
# Batches hold one sample each and keep the dataset order (shuffle=False is the default).
my_data = TensorDataset(my_3d_maps, my_2d_maps, my_weights, my_indices)
data_loader = DataLoader(my_data, batch_size=1, shuffle=False)

In [None]:
# Build an ensemble of n_networks 3D networks (1 input channel -> 1 output channel),
# presumably sparse mixed-scale networks given the `3dsms` name.
# NOTE(review): meaning of alpha, gamma, hidden_channels and parameter_bounds is
# defined in fusecam.aiutil.ensembling — confirm there.
n_networks = 7
networks = ensembling.construct_3dsms_ensembler(
    n_networks=n_networks,
    in_channels=1,
    out_channels=1,
    layers=20,
    alpha=0.00,
    gamma=0.00,
    hidden_channels=[5],
    parameter_bounds=[80000, 90000],
)

In [None]:
# For every ensemble member, print the count from helpers.count_parameters and draw the network.
for net in networks:
    n_params = helpers.count_parameters(net)
    print(n_params)
    a, b, c = draw_sparse_network.draw_network(net)

In [None]:
# Train every ensemble member independently: MSE loss, Adam (lr=0.0051), 500 epochs, CPU.
# NOTE(review): arguments follow the same order as the notebook's `train_model`
# (net, loss, optimizer, loader, epochs, interpolator, device) — confirm against
# fusecam.aiutil.train_scripts.train_volume_on_slice.
for net in networks:
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0051)
    train_scripts.train_volume_on_slice(net,
                                        loss_function,
                                        optimizer,
                                        data_loader,
                                        500,
                                        interpolate.inverse_distance_weighting_with_weights,
                                        device='cpu')

In [None]:
def train_model(net, loss_function, optimizer, dataloader, num_epochs, interpolate_function, device='cuda:0'):
    """Train a 3D network using supervision from 2D slices sampled out of its output.

    For every sample in a batch, the network's 3D output is flattened and passed to
    ``interpolate_function(flat_output, indices, weights)``. Positions where the
    interpolation is NaN are ignored, and ``loss_function`` compares the remaining
    values with the flattened 2D reference slice. Per-sample losses are summed into
    one batch loss, which gets a single backward pass and optimizer step. After each
    epoch the batch loss, averaged over all batches, is printed.

    Parameters
    ----------
    net : torch.nn.Module
        Network mapping a batch of 3D inputs to a batch of 3D outputs. It is moved
        to ``device`` and put in training mode.
    loss_function : callable
        ``loss_function(prediction, target) -> scalar tensor`` (e.g. ``nn.MSELoss()``).
    optimizer : torch.optim.Optimizer
        Optimizer over ``net``'s parameters.
    dataloader : iterable
        Yields ``(img_tensor_3d, flat_2d_tensor, weights, indices)`` batches of
        tensors and supports ``len()``.
    num_epochs : int
        Number of passes over ``dataloader``.
    interpolate_function : callable
        ``interpolate_function(flat_values, indices, weights) -> flat tensor`` that
        samples the flattened 3D output at the slice positions.
    device : str or torch.device, optional
        Device for the network and every batch tensor (default ``'cuda:0'``).

    Returns
    -------
    None
        ``net`` and ``optimizer`` are updated in place.
    """
    net.to(device)  # Move the network to the specified device
    net.train()  # Set the network to training mode

    # max(..., 1) keeps the per-epoch average defined for an empty dataloader.
    n_batches = max(len(dataloader), 1)

    for epoch in range(num_epochs):
        running_loss = 0.0

        for batch in dataloader:
            img_tensor_3d, flat_2d_tensor, weights, indices = [item.to(device) for item in batch]

            # Forward pass
            outputs = net(img_tensor_3d)

            # Sum the per-sample slice losses; stays None if no sample is usable.
            loss = None
            for img3d, img2d, ws, idx in zip(outputs, flat_2d_tensor, weights, indices):
                interp = interpolate_function(img3d.flatten(), idx, ws)
                not_nan_sel = ~torch.isnan(interp)
                if not not_nan_sel.any():
                    # No valid positions: a loss on empty tensors (e.g. MSELoss) is
                    # NaN and would turn the reported epoch loss into NaN.
                    continue
                sample_loss = loss_function(interp[not_nan_sel], img2d[not_nan_sel])
                loss = sample_loss if loss is None else loss + sample_loss

            if loss is None:
                # Nothing in this batch produced a usable loss; skip the update.
                continue

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Print statistics
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / n_batches}")

In [None]:
# Create an L1 loss function
loss_function = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.0051)  # You can adjust the learning rate as needed
train_model(net, 
            loss_function, 
            optimizer, 
            data_loader, 
            500, 
            interpolate.inverse_distance_weighting_with_weights, device='cpu')


In [None]:
# Run `net` on the CPU over the low-res volume (batch of one) without tracking gradients.
with torch.no_grad():
    low_input = torch.Tensor(low_map).unsqueeze(0)
    tmp3 = net.cpu()(low_input)

In [None]:
# Ensemble prediction on the low-res input: element-wise mean `m` and population
# standard deviation `s` over every network in `networks`. Predictions are computed
# here directly, so the cell does not depend on separately stored per-network outputs.
# NOTE(review): networks are not switched to eval() here, matching the single-network
# inference cell — confirm whether eval mode is wanted (e.g. with batch norm).
with torch.no_grad():
    ensemble_input = torch.Tensor(low_map).unsqueeze(0)
    ensemble_preds = torch.stack([member.cpu()(ensemble_input) for member in networks])
m = ensemble_preds.mean(dim=0)
# std(unbiased=False) equals sqrt(E[x^2] - E[x]^2), but it cannot turn NaN from a
# slightly negative rounding error the way the explicit formula can.
s = ensemble_preds.std(dim=0, unbiased=False)

In [None]:
# napari is imported here, not in the top import cell, so the rest of the notebook
# still runs where this GUI dependency is not installed.
import napari

# Inspect the low-res input, ensemble mean, ensemble std and high-res reference.
# Explicit layer names keep the mean and std layers distinguishable in the layer list.
v = napari.Viewer()
v.add_image(low_map, name="low_map")
v.add_image(m.numpy()[0], name="ensemble_mean")
v.add_image(s.numpy()[0], name="ensemble_std")
v.add_image(high_map, name="high_map")

In [None]:
# Compare the slice at last-axis index scale//2 - 15 of the ensemble mean, the
# high-res reference and the low-res input, in that order.
for slice_2d in (m.numpy()[0, 0, :, :, scale // 2 - 15],
                 high_map[0, :, :, scale // 2 - 15],
                 low_map[0, :, :, scale // 2 - 15]):
    plt.imshow(slice_2d)
    plt.colorbar()
    plt.show()