# Synthetic dataset creation

The dataset must contain multiple variants of each class.

# Create a discretisation by recursively subdividing the fundamental domain

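Initial points on one triangular face of the icosahedron (used here as the fundamental domain) and their orbits under the icosahedral rotation group. Note that one face covers three fundamental domains of the 60-element rotation group. So if the group is aligned with this icosahedron, each orbit point is generated three times.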

In [None]:
import numpy as np

# Recursive subdivision of spherical triangles (used on one icosahedron face):
def discretize_fundamental_domain(vertices, depth, *, verbose=True):
    """Subdivide spherical triangles and return one unit point per final triangle.

    Each level splits every triangle (a, b, c) into four by inserting the
    normalised edge midpoints, so ``depth`` levels turn each input triangle
    into ``4**depth`` triangles. The returned points are the normalised
    centroids of the final triangles, in the order the triangles are produced
    (downstream code reshapes orbit values relying on this order).

    Args:
        vertices: list of triangles, each a 3-tuple of 3-D corner vectors.
            Corners are normalised first, so they need not lie on the unit
            sphere; mixing unnormalised corners with unit midpoints would
            bias the subdivision toward the original corners.
        depth: number of subdivision levels; 0 returns the centroid of each
            input triangle.
        verbose: if True (default, as before), print the triangle count
            after each level.

    Returns:
        list of unit-norm numpy arrays of length ``len(vertices) * 4**depth``.
    """

    def _normalise(v):
        return v / np.linalg.norm(v)

    def _split(triangle):
        # Four children: one per corner plus the central midpoint triangle.
        a, b, c = triangle
        ab, bc, ca = _normalise(a + b), _normalise(b + c), _normalise(c + a)
        return [(a, ab, ca), (b, ab, bc), (c, bc, ca), (ab, bc, ca)]

    triangles = [
        tuple(_normalise(np.asarray(corner, dtype=float)) for corner in triangle)
        for triangle in vertices
    ]
    for _ in range(depth):
        triangles = [child for triangle in triangles for child in _split(triangle)]
        if verbose:
            print(len(triangles))

    return [_normalise(a + b + c) for a, b, c in triangles]


    
# Golden ratio: icosahedron vertices are cyclic permutations of (0, +-1, +-phi).
phi = (1 + np.sqrt(5)) / 2
# Seed region: the icosahedron face spanned by three mutually adjacent vertices
# (pairwise distance 2). Six subdivision levels give 4**6 = 4096 seed points.
vertices = [tuple(np.array(corner) for corner in ((1, phi, 0), (0, 1, phi), (-1, phi, 0)))]
points = discretize_fundamental_domain(vertices, 6)
num_initial_points = len(points)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objs as go
from scipy.spatial.transform import Rotation as R

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Icosahedral rotation group (proper rotations only, 60 elements).
icosahedral_group = R.create_group('I')
print(f'Number of rotations in icosahedral group: {len(icosahedral_group)}')
rotation_matrices = icosahedral_group.as_matrix()  # shape (60, 3, 3)

# Orbit of every seed point under the group, vectorised instead of one Python
# matmul per (point, rotation) pair. Row k * 60 + j holds
# points[k] @ rotation_matrices[j] (point-major order); later cells reshape
# per-point values to (num_initial_points, 60) relying on this order.
# p @ M applies M^T = M^-1 to p; the group is closed under inverses, so the
# set of images is the same as applying M.
orbit = np.einsum('ki,gij->kgj', np.asarray(points), rotation_matrices).reshape(-1, 3)

print(f'Length of orbit: {len(orbit)}')
# Duplicate check (e.g. seed region larger than a fundamental domain, or seeds
# on a symmetry axis). Rotated copies of the same point agree only up to
# floating-point error, so exact np.unique on raw floats can miss them; round
# first. Adding 0.0 maps -0.0 to 0.0 so signed zeros compare equal.
unique_orbit = np.unique(np.round(orbit, decimals=8) + 0.0, axis=0)
print(f'Unique points in orbit: {unique_orbit.shape[0]}')
x, y, z = orbit.T

In [None]:
# Tune the marker size/opacity below to change how the orbit points are drawn.
trace = go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers',
    marker={'size': 0.4, 'opacity': 0.3},
)

# Cube-shaped scene with identical [-1, 1] axes so the sphere is not distorted.
layout = go.Layout(
    scene={
        **{f'{axis}axis': {'nticks': 4, 'range': [-1, 1]} for axis in 'xyz'},
        'aspectratio': {'x': 1, 'y': 1, 'z': 1},
    }
)

fig = go.Figure(data=[trace], layout=layout)
fig.show()

# Shape functions with varying parameters:

In [None]:
# Great band describes a band around the equator of the sphere
def great_band(p, width=np.pi / 6):
    """Return 1 if ``p`` lies strictly within ``width`` radians of the equator, else 0.

    Args:
        p: point on the unit sphere [x, y, z] (already normalised).
        width: half-width of the band in radians, measured in polar angle
            from the equator.
    """
    # Clip guards against |z| slightly above 1 from rounding, for which
    # arccos would return NaN (and every comparison would be False).
    phi = np.arccos(np.clip(p[2], -1.0, 1.0))  # polar angle
    if (np.pi / 2 - width) < phi < (np.pi / 2 + width):
        return 1
    return 0
# Checkerboard pattern
def checkerboard(p, n1=4, n2=2):
    """Return a 0/1 checkerboard value for a point on the unit sphere.

    The azimuth is split into ``n1`` equal sectors and the polar angle into
    ``n2`` equal bands; a cell maps to 1 when sector index + band index is even.

    Args:
        p: point on the unit sphere [x, y, z] (already normalised).
        n1: number of divisions of the azimuthal angle theta.
        n2: number of divisions of the polar angle phi.
    """
    theta = np.arctan2(p[1], p[0]) % (2 * np.pi)  # azimuthal angle
    # Clip guards against |z| slightly above 1 from rounding; an unclipped
    # NaN would make the int() conversion below raise.
    phi = np.arccos(np.clip(p[2], -1.0, 1.0))  # polar angle in [0, pi]

    # Width of each division
    width_theta = (2 * np.pi) / n1
    width_phi = np.pi / n2

    # Clamp: theta can round up to exactly 2*pi for tiny negative angles, and
    # phi equals pi at the south pole; both would otherwise index a
    # nonexistent sector n1 / band n2.
    index_theta = min(int(theta // width_theta), n1 - 1)
    index_phi = min(int(phi // width_phi), n2 - 1)

    return 1 if (index_theta + index_phi) % 2 == 0 else 0
# Counterpart of great band: marks caps around both poles
def polar_caps(p, width=np.pi / 3):
    """Return 1 if ``p`` lies strictly within ``width`` radians of either pole, else 0.

    Args:
        p: point on the unit sphere [x, y, z] (already normalised).
        width: angular radius of each cap, in polar angle from its pole.
    """
    # Clip guards against |z| slightly above 1 from rounding: arccos would
    # return NaN and points exactly at a pole would wrongly get 0.
    phi = np.arccos(np.clip(p[2], -1.0, 1.0))  # polar angle
    return 1 if (phi < width or phi > np.pi - width) else 0

# different type of shape function, defined on the matrix only not a shape in space

def oneblob(n, num_initial_points):
    """Return a (num_initial_points, 60) 0/1 float tensor with ``n`` ones filled column by column.

    Columns ``0 .. n // num_initial_points - 1`` are entirely ones and the next
    column has its first ``n % num_initial_points`` rows set. The "shape" is
    defined directly on the orbit matrix, not as a function on the sphere.

    Raises:
        ValueError: if ``num_initial_points`` is not positive or ``n`` is
            outside ``[0, num_initial_points * 60]``.
    """
    if num_initial_points <= 0:
        raise ValueError(f"num_initial_points must be positive, got {num_initial_points}")
    total = num_initial_points * 60
    if not 0 <= n <= total:
        raise ValueError(f"n must be in [0, {total}], got {n}")

    shapematrix = torch.zeros((num_initial_points, 60))
    full_columns, remainder = divmod(n, num_initial_points)
    shapematrix[:, :full_columns] = 1
    if remainder:  # avoid indexing column 60 when the matrix is completely full
        shapematrix[:remainder, full_columns] = 1
    return shapematrix

# A function mapping to [0,1] as opposed to binary output.
def star_pattern(p, n_arms=5, contrast=10):
    """
    Generates a star-like pattern on the sphere.

    Args:
    p (array): normalized point on the sphere [x, y, z].
    n_arms (int): azimuthal frequency; |cos(n_arms * theta)| has 2 * n_arms
        lobes over a full turn.
    contrast (float): frequency of the oscillation in the polar angle,
        |cos(contrast * phi)|; larger values give more latitude bands.

    Returns:
    float: a value between 0 and 1 representing the color intensity.
    """
    theta = np.arctan2(p[1], p[0]) % (2 * np.pi)  # azimuthal angle
    # Clip guards against |z| slightly above 1 from rounding (arccos -> NaN).
    phi = np.arccos(np.clip(p[2], -1.0, 1.0))  # polar angle

    # Fractional exponents flatten the peaks of the cosines.
    theta_effect = np.abs(np.cos(n_arms * theta)) ** 0.4
    phi_effect = np.abs(np.cos(contrast * phi)) ** 0.5
    return (theta_effect * phi_effect) ** 2  # square to enhance contrast
    

In [None]:
# Example dataset creation for torch, including random rotations done by rolling the shape matrix
class ShapeDataset(Dataset):
    """Map-style dataset of per-shape orbit matrices with class labels.

    Each item is ``(flattened_shape, label)``: the shape matrix flattened
    row-major into a 1-D float32 tensor, plus ``labels[idx]`` unchanged.
    """

    def __init__(self, shapes, labels, rotate=False):
        """
        shapes: tensor or array of shape (num_shapes, num_initial_points, 60):
            one row per seed point, one column per group element.
        labels: tensor of shape (num_shapes,) with integer labels for each shape.
        rotate: if True, each access cyclically shifts the columns of the
            shape matrix by a random amount.
        """
        self.shapes = shapes
        self.labels = labels
        self.rotate = rotate

    def __len__(self):
        return len(self.shapes)

    def __getitem__(self, idx):
        # as_tensor accepts tensors and numpy arrays alike, and avoids the
        # copy-construct warning that torch.tensor(tensor) emits.
        shape = torch.as_tensor(self.shapes[idx])
        if self.rotate:
            # Use torch's RNG rather than numpy's global one: DataLoader seeds
            # torch per worker, whereas numpy state is duplicated across workers.
            shift = int(torch.randint(shape.shape[1], (1,)))
            # NOTE(review): the icosahedral group has no element of order 60, so
            # a cyclic column shift (e.g. by 1) cannot equal applying one fixed
            # rotation; this is a permutation augmentation, not a true rotation.
            # Confirm that is intended.
            shape = torch.roll(shape, shifts=shift, dims=1)
        shape_flat = shape.flatten().to(torch.float32)
        label = self.labels[idx]
        return shape_flat, label
    

In [None]:
# Parameter sweeps: one dict of keyword arguments per variation of each shape.
checkerboard_params = [dict(n1=n1, n2=n2) for n1 in range(1, 10) for n2 in range(2, 10)]
great_band_params = [dict(width=w) for w in np.linspace(np.pi / 15, np.pi / 3, 80)]
polar_caps_params = [dict(width=w) for w in np.linspace(np.pi / 15, np.pi / 2.2, 80)]
star_params = [dict(n_arms=arms, contrast=1) for arms in range(1, 5)]

def generate_variations(orbit, shape_fn, params_list, num_initial_points=None, group_order=60):
    """
    Evaluate ``shape_fn`` on every orbit point, once per parameter set.

    Args:
    - orbit: array of shape (num_initial_points * group_order, 3), ordered
      point-major: row ``k * group_order + j`` is seed point ``k`` under
      group element ``j``.
    - shape_fn: callable ``shape_fn(point, **params)`` returning a number
      (e.g. checkerboard or great_band).
    - params_list: a list of keyword-argument dicts, one per variation.
    - num_initial_points: rows of each variation matrix. Defaults to
      ``len(orbit) // group_order``.
    - group_order: number of group elements (columns); 60 for the
      icosahedral rotation group.

    Returns:
    - A tensor of shape (len(params_list), num_initial_points, group_order).
      Its dtype follows the values ``shape_fn`` returns (integer for the 0/1
      shape functions, float for float-valued ones).
    """
    if num_initial_points is None:
        num_initial_points = len(orbit) // group_order
    grid_shape = (num_initial_points, group_order)
    if not params_list:
        return torch.empty((0, *grid_shape))

    variations = [
        np.array([shape_fn(point, **params) for point in orbit]).reshape(grid_shape)
        for params in params_list
    ]
    # Stack into one ndarray first: torch.tensor on a list of ndarrays is very slow.
    return torch.from_numpy(np.stack(variations))

# Evaluate every shape family on the orbit, one matrix per parameter set.
checkerboard_variations = generate_variations(orbit, checkerboard, checkerboard_params)
great_band_variations = generate_variations(orbit, great_band, great_band_params)
polar_caps_variations = generate_variations(orbit, polar_caps, polar_caps_params)
star_variations = generate_variations(orbit, star_pattern, star_params)  # no label assigned here

# Class labels: 0 = checkerboard, 1 = great band, 2 = polar caps.
checkerboard_labels = torch.full((checkerboard_variations.shape[0],), 0, dtype=torch.long)
great_band_labels = torch.full((great_band_variations.shape[0],), 1, dtype=torch.long)
polar_caps_labels = torch.full((polar_caps_variations.shape[0],), 2, dtype=torch.long)

print(f"number of checkerboards: {len(checkerboard_variations)}, band: {len(great_band_variations)}, Polar caps:{len(polar_caps_variations)}")

# Plotting and visualising
This is an interactive plot which makes defining spherical shapes far easier.

For the binary (0/1) shape functions, set x, y and z as follows:
x=x[example_shape == 1],
y=y[example_shape == 1],
z=z[example_shape == 1],

so that only the points with value 1 are shown.
Otherwise keep x=x, y=y, z=z as already implemented.

In [None]:
from plotly.subplots import make_subplots

x, y, z = orbit.T
# shape_type = [oneblob(20, num_initial_points)] * 4  # matrix-defined alternative
shape_type = star_variations
# Number of variations to plot, laid out on a grid with num_cols columns
num_variations = 4
num_cols = 2
num_rows = -(-num_variations // num_cols)  # ceiling division

fig = make_subplots(
    rows=num_rows, cols=num_cols,
    specs=[[{'type': 'scatter3d'} for _ in range(num_cols)] for _ in range(num_rows)],
    subplot_titles=[f"Variation {i+1}" for i in range(num_variations)],
)

# Select distinct random variations of the chosen shape family
random_indices = np.random.choice(len(shape_type), num_variations, replace=False)

for i, idx in enumerate(random_indices):
    # Flatten to orbit order and hand plotly a numpy array, not a torch tensor.
    example_shape = shape_type[idx].reshape(-1).numpy()
    row, col = divmod(i, num_cols)

    fig.add_trace(
        go.Scatter3d(
            # For binary shapes, plot only the 1s instead:
            # x=x[example_shape == 1],
            # y=y[example_shape == 1],
            # z=z[example_shape == 1],
            x=x,
            y=y,
            z=z,
            mode='markers',
            marker=dict(
                size=1,
                color=example_shape,  # per-point values drive the colour
                colorscale='Plasma',
                colorbar=dict(title='Pattern Intensity'),
                opacity=1,
            ),
            name="",
        ),
        row=row + 1, col=col + 1,
    )

# Configure the layout once, after all traces are added. update_scenes applies
# the axis settings to every 3D subplot; layout.scene alone only covers the first.
fig.update_layout(
    height=800, width=800,
    title_text=f"{num_variations} Random Variations of the Example Shape",
)
fig.update_scenes(
    xaxis=dict(nticks=4, range=[-1, 1]),
    yaxis=dict(nticks=4, range=[-1, 1]),
    zaxis=dict(nticks=4, range=[-1, 1]),
    aspectratio=dict(x=1, y=1, z=1),
)

fig.show()

# Outline for testing types of models
Examples:
- a model with no rotation invariance
- a rotation-invariant model
- a model trained on a rotated dataset


In [None]:
# Classes 0-2 only (checkerboard, great band, polar caps); even-indexed
# variations train, odd-indexed variations test.
# NOTE(review): variations come from smooth parameter sweeps, so every test
# sample has a near-identical neighbour in the training set; this split likely
# overestimates generalisation. Consider holding out a parameter range.
all_variations = torch.cat((checkerboard_variations, great_band_variations, polar_caps_variations))
all_labels = torch.cat((checkerboard_labels, great_band_labels, polar_caps_labels))

train_dataset = ShapeDataset(all_variations[0::2], all_labels[0::2], rotate=False)
test_dataset = ShapeDataset(all_variations[1::2], all_labels[1::2], rotate=False)

batch_size = 16

train_dataloader, test_dataloader = [
    DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, test_dataset)
]


In [None]:
import torch.nn.functional as F
class Classifier(nn.Module):
    """Three-layer MLP producing class logits from a flattened orbit matrix.

    Args:
        input_dim: length of each flattened input vector. Defaults to
            ``num_initial_points * 60`` (the notebook global, read at
            construction time) so ``Classifier()`` keeps working.
        hidden_dim: width of both hidden layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim=None, hidden_dim=100, num_classes=3):
        super().__init__()
        if input_dim is None:
            input_dim = num_initial_points * 60
        # Layer attribute names unchanged so existing state_dicts still load.
        self.input = nn.Linear(input_dim, hidden_dim)
        self.hidden = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return unnormalised logits for ``x`` of shape (..., input_dim)."""
        x = F.relu(self.input(x))
        x = F.relu(self.hidden(x))
        return self.output(x)

In [None]:
# Model, loss and optimiser for the classification experiment.
net = Classifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)
num_epochs = 50  # full passes over train_dataloader

def calculate_accuracy(y_pred, y_true):
    """Fraction of rows whose highest score in ``y_pred`` matches ``y_true``.

    y_pred: (batch, num_classes) scores or logits; y_true: (batch,) labels.
    Returns a Python float.
    """
    hits = (y_pred.argmax(dim=1) == y_true).sum().item()
    return hits / y_true.shape[0]

train_losses = []
test_losses = []
train_error_rates = []
test_error_rates = []

for epoch in range(num_epochs):
    # ---- training pass ----
    net.train()
    train_loss = 0.0  # sum of per-sample losses over the epoch
    train_correct = 0
    train_seen = 0
    for inputs, targets in train_dataloader:
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        n_batch = targets.shape[0]
        # criterion returns the batch mean; weight by batch size so a short
        # final batch is not over-weighted in the epoch averages.
        train_loss += loss.item() * n_batch
        train_correct += torch.eq(torch.argmax(outputs, dim=1), targets).sum().item()
        train_seen += n_batch
    train_accuracy = train_correct / train_seen

    train_losses.append(train_loss / train_seen)
    train_error_rates.append(1 - train_accuracy)

    # ---- evaluation pass ----
    net.eval()
    test_loss = 0.0
    test_correct = 0
    test_seen = 0
    with torch.no_grad():
        for inputs, targets in test_dataloader:
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            n_batch = targets.shape[0]
            test_loss += loss.item() * n_batch
            test_correct += torch.eq(torch.argmax(outputs, dim=1), targets).sum().item()
            test_seen += n_batch
    test_accuracy = test_correct / test_seen

    # Per-sample average, matching the training loss, so both curves are comparable.
    test_losses.append(test_loss / test_seen)
    test_error_rates.append(1 - test_accuracy)

    print(f"Epoch {epoch+1}, Train Loss: {train_losses[-1]}, Test Loss: {test_losses[-1]}, Test Error Rate: {test_error_rates[-1]}, Accuracy: {test_accuracy}")

import matplotlib.pyplot as plt

# Side-by-side curves: loss on the left, error rate on the right.
curves_fig, (loss_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (loss_ax, 'Loss over epochs', 'Loss',
     ((train_losses, 'Train Loss'), (test_losses, 'Test Loss'))),
    (error_ax, 'Error Rate over epochs', 'Error Rate',
     ((train_error_rates, 'Train Error Rate'), (test_error_rates, 'Test Error Rate'))),
)
for ax, title, ylabel, series in panels:
    for values, label in series:
        ax.plot(values, label=label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()

plt.show()