<h1> <center> Clustered Ti Orientations </center></h1>

This notebook illustrates clustering of orientations while accounting for crystal symmetry. It begins by importing the relevant data and then shows how the symmetry-reduced misorientation distance matrix, D, can be formed. For this example notebook that stage can be skipped by loading a precomputed D directly. The orientations are then clustered with the DBSCAN algorithm and the results plotted.
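Written out, the symmetry-aware distance between two orientations $o_a$ and $o_b$ (read off the `D6.outer(~ori).outer(ori).outer(D6)` expression used below) is the smallest misorientation angle over every pair of symmetry operators $g_i, g_j$ in the point group $G$ (here D6, with 12 rotations):

$$
D_{ab} = \min_{g_i,\, g_j \in G} \, \omega\!\left(g_i \, o_a^{-1} \, o_b \, g_j\right), \qquad \omega(q) = 2 \arccos \lvert q_0 \rvert ,
$$

where $q_0$ is the scalar part of the unit quaternion $q$. The absolute value makes $q$ and $-q$, which describe the same rotation, give the same angle $\omega \in [0, \pi]$.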

In [None]:
%matplotlib qt5

# Important external dependencies
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN

# orix dependencies (tested with orix 0.1.1)
from orix.quaternion.orientation import Orientation, Misorientation
from orix.quaternion.rotation import Rotation
from orix.quaternion.symmetry import D6
from orix.quaternion.orientation_region import OrientationRegion
from orix.vector.neo_euler import AxAngle
from orix.vector import Vector3d
from orix import plot

# Colorisation
from skimage.color import label2rgb
from matplotlib.colors import to_rgb, to_hex
MPL_COLORS_RGB = [to_rgb('C{}'.format(i)) for i in range(10)]
MPL_COLORS_HEX = [to_hex(c) for c in MPL_COLORS_RGB]

# Animation
import matplotlib.animation as animation

# Visualisation
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib.lines import Line2D


plt.rc('font', size=6)

In [None]:
# Load data (euler angles, in degrees, Bunge convention) from CTF file
filepath = './data/Ti_orientations.ctf'
dat = np.loadtxt(filepath, skiprows=1)[:, :3]
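The Euler angles loaded here are in the Bunge (ZXZ) convention and are turned into quaternions in the next cell by `Orientation.from_euler`. As a rough sketch of what that conversion involves, the composition q = q_z(φ1) q_x(Φ) q_z(φ2) has a closed form. This assumes an active, scalar-first convention; orix may return the inverse (conjugate) quaternion depending on its crystal-to-lab convention, so treat this as illustrative only.

```python
import numpy as np

def bunge_to_quaternion(euler_deg):
    """Bunge (ZXZ) Euler angles in degrees, shape (..., 3), to scalar-first unit quaternions.

    Composes q = q_z(phi1) * q_x(Phi) * q_z(phi2) (assumed active convention;
    orix may return the conjugate of this).
    """
    phi1, Phi, phi2 = np.moveaxis(np.radians(euler_deg), -1, 0)
    sigma = 0.5 * (phi1 + phi2)   # half-sum of the two rotations about z
    delta = 0.5 * (phi1 - phi2)   # half-difference of the two rotations about z
    c, s = np.cos(0.5 * Phi), np.sin(0.5 * Phi)
    return np.stack(
        [c * np.cos(sigma), s * np.cos(delta), s * np.sin(delta), c * np.sin(sigma)],
        axis=-1,
    )

# phi1 = 90 deg alone is a 90 deg rotation about z: w = z = cos(45 deg), x = y = 0
q = bunge_to_quaternion(np.array([90.0, 0.0, 0.0]))
```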

In [None]:
ori = Orientation.from_euler(np.radians(dat))  # convert Bunge Euler angles (radians) to quaternions
ori = ori.reshape(381, 507)  # reshape to the spatial dimensions of the map
ori = ori[-100:, :200]  # keep a 100 x 200 subset of the data to speed up the computation
fundamental_region = OrientationRegion.from_symmetry(D6)  # define the fundamental zone for D6 (622) symmetry
print(ori.size)
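For this subset `ori.size` is 100 x 200 = 20,000, so the pairwise distance matrix has 4 x 10^8 entries. A back-of-the-envelope estimate of what the options below have to hold in memory, assuming float64 storage, 4 floats per quaternion and the 12 D6 rotations applied on both sides:

```python
# Rough memory estimate for the 100 x 200 subset.
# Assumptions: float64 storage, 4 floats per quaternion, 12 rotations in D6 on each side.
n = 100 * 200          # ori.size for the subset
n_sym = 12             # rotations in D6 (622)
bytes_per_float = 8

d_bytes = n * n * bytes_per_float                             # final distance matrix D
option1_bytes = n_sym * n * n * n_sym * 4 * bytes_per_float   # full misorientation tensor
option2_bytes = n_sym * n * n_sym * 4 * bytes_per_float       # one row per loop iteration

for name, b in [('D', d_bytes), ('Option 1 tensor', option1_bytes), ('Option 2 per row', option2_bytes)]:
    print(f'{name}: {b / 1e9:.3g} GB')
```

D alone needs about 3.2 GB, the full Option 1 tensor about 1.8 TB, and a single Option 2 row about 92 MB.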

<h2><center> Create D via a correct approach incorporating symmetry </center></h2>

**Option 1: quick, will wreck your RAM**

Computes every symmetry-equivalent misorientation for every pair of orientations in a single array, then takes the minimum angle over the two symmetry axes.
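To make the axis bookkeeping concrete, here is a self-contained numpy sketch of the same all-at-once computation on a few random orientations. The Hamilton product, the scalar-first layout and the 622 rotation set (c along z, one two-fold along x) are assumptions of this sketch, not orix internals. The symmetry operators occupy the first and last axes, so those are the axes to minimise over.

```python
import numpy as np

def qmul(p, q):
    """Hamilton product of scalar-first quaternion arrays (..., 4), with broadcasting."""
    pw, px, py, pz = np.moveaxis(p, -1, 0)
    qw, qx, qy, qz = np.moveaxis(q, -1, 0)
    return np.stack([
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    ], axis=-1)

def conj(q):
    """Quaternion conjugate, which is the inverse for unit quaternions."""
    return q * np.array([1.0, -1.0, -1.0, -1.0])

def rotation_angle(q):
    """Rotation angle in radians; using |w| gives q and -q the same angle."""
    return 2 * np.arccos(np.clip(np.abs(q[..., 0]), 0.0, 1.0))

def d6_quaternions():
    """12 rotations of point group 622 (assumes c along z, a two-fold along x)."""
    t = np.arange(6) * np.pi / 6   # half-angles of six-fold rotations / azimuths of two-fold axes
    zero = np.zeros(6)
    sixfold = np.stack([np.cos(t), zero, zero, np.sin(t)], axis=-1)   # k * 60 deg about z
    twofold = np.stack([zero, np.cos(t), np.sin(t), zero], axis=-1)   # 180 deg about axes every 30 deg in xy
    return np.concatenate([sixfold, twofold])

def distance_matrix_broadcast(Q, G):
    """Option 1 analogue: build all (S, N, N, S) misorientations, minimise over both S axes."""
    left = qmul(G[:, None, :], conj(Q)[None, :, :])            # (S, N, 4): g_i o_a^-1
    mid = qmul(left[:, :, None, :], Q[None, None, :, :])        # (S, N, N, 4): g_i o_a^-1 o_b
    full = qmul(mid[..., None, :], G[None, None, None, :, :])   # (S, N, N, S, 4)
    return rotation_angle(full).min(axis=(0, -1))               # (N, N)

# Usage on a few random orientations
rng = np.random.default_rng(0)
Q = rng.normal(size=(5, 4))
Q /= np.linalg.norm(Q, axis=-1, keepdims=True)  # normalise to unit quaternions
G = d6_quaternions()
D = distance_matrix_broadcast(Q, G)
print(np.degrees(D).round(1))
```

The intermediate array grows as S²N², which is why this route is only practical for small N.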

In [None]:
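# Compute misorientations between every pair, with D6 operators on both sides
confirm = input('Are you sure? (y/n) ')
if confirm == 'y':
    # Shape (12, *ori.shape, *ori.shape, 12): first and last axes are the D6 operators
    misori_equiv = D6.outer(~ori).outer(ori).outer(D6)
    # Minimise over the two symmetry axes only, leaving D with shape ori.shape + ori.shape
    D = misori_equiv.angle.data.min(axis=(0, -1))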

**Option 2: medium-speed, should be OK for RAM**

Precomputes the symmetry-equivalent orientations once, then fills D one row at a time.
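A numpy sketch of the same row-by-row idea: the right-hand equivalents o_b g_j are precomputed once (the analogue of `OS2`), and each loop iteration only holds an (S, N, S) array. The helpers (Hamilton product, scalar-first quaternions, 622 rotations with c along z and a two-fold along x) are assumptions of this sketch rather than orix internals.

```python
import numpy as np

def qmul(p, q):
    """Hamilton product of scalar-first quaternion arrays (..., 4), with broadcasting."""
    pw, px, py, pz = np.moveaxis(p, -1, 0)
    qw, qx, qy, qz = np.moveaxis(q, -1, 0)
    return np.stack([
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    ], axis=-1)

def conj(q):
    """Quaternion conjugate (inverse for unit quaternions)."""
    return q * np.array([1.0, -1.0, -1.0, -1.0])

def rotation_angle(q):
    """Rotation angle in radians, insensitive to the sign of q."""
    return 2 * np.arccos(np.clip(np.abs(q[..., 0]), 0.0, 1.0))

def d6_quaternions():
    """12 rotations of point group 622 (assumes c along z, a two-fold along x)."""
    t = np.arange(6) * np.pi / 6
    zero = np.zeros(6)
    sixfold = np.stack([np.cos(t), zero, zero, np.sin(t)], axis=-1)
    twofold = np.stack([zero, np.cos(t), np.sin(t), zero], axis=-1)
    return np.concatenate([sixfold, twofold])

def distance_matrix_rows(Q, G):
    """Option 2 analogue: precompute o_b g_j once, then fill one row of D per iteration."""
    n = len(Q)
    D = np.full((n, n), np.inf)
    OS2 = qmul(Q[:, None, :], G[None, :, :])               # (N, S, 4): o_b g_j
    for a in range(n):
        left = qmul(G, conj(Q[a]))                          # (S, 4): g_i o_a^-1
        misori = qmul(left[:, None, None, :], OS2[None])    # (S, N, S, 4)
        D[a] = rotation_angle(misori).min(axis=(0, -1))     # (N,)
    return D

rng = np.random.default_rng(1)
Q = rng.normal(size=(6, 4))
Q /= np.linalg.norm(Q, axis=-1, keepdims=True)
G = d6_quaternions()
D = distance_matrix_rows(Q, G)
```

Peak memory is now set by one row rather than the whole tensor, at the cost of a Python loop over N.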

In [None]:
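confirm = input('This might take some time, are you sure? (y/n) ')
if confirm == 'y':
    from tqdm.notebook import tqdm

    # Start from +inf so that np.minimum keeps the smallest angle found
    D = np.full(ori.shape + ori.shape, np.inf)

    # Right-hand symmetry equivalents of every orientation, shape ori.shape + (12,)
    OS2 = ori.outer(D6)

    for i in tqdm(range(ori.size)):
        idx = np.unravel_index(i, ori.shape)
        # Shape (12, *ori.shape, 12): left operator, second orientation, right operator
        misori = D6.outer(~ori[idx]).outer(OS2)
        d = misori.angle.data.min(axis=(0, -1))
        D[idx[0], idx[1], ...] = np.minimum(D[idx[0], idx[1], ...], d)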

**Option 3: slow, safe for RAM**

Iterates through every pair of orientations.
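Because D is symmetric, the pairwise loop needs only one evaluation per unordered pair, N(N+1)/2 in total (200,010,000 for N = 20,000). A minimal sketch of that bookkeeping, iterating `combinations_with_replacement` lazily instead of building a list of all pairs. The pair distance here is a hypothetical stand-in that ignores crystal symmetry, to keep the example short.

```python
import numpy as np
from itertools import combinations_with_replacement

def misorientation_angle(q1, q2):
    """Angle of q1^-1 q2 for unit quaternions, ignoring crystal symmetry: 2 arccos|q1 . q2|."""
    return 2 * np.arccos(np.clip(abs(np.dot(q1, q2)), 0.0, 1.0))

def distance_matrix_pairs(Q, pair_distance):
    """Fill a symmetric matrix with one call per unordered pair (i <= j)."""
    n = len(Q)
    D = np.empty((n, n))
    n_calls = 0
    # Iterate lazily: list(...) would hold all n(n+1)/2 index pairs in memory at once
    for i, j in combinations_with_replacement(range(n), 2):
        D[i, j] = D[j, i] = pair_distance(Q[i], Q[j])
        n_calls += 1
    return D, n_calls

rng = np.random.default_rng(2)
Q = rng.normal(size=(7, 4))
Q /= np.linalg.norm(Q, axis=-1, keepdims=True)
D, n_calls = distance_matrix_pairs(Q, misorientation_angle)
```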

In [None]:
confirm = input('This will take some time, are you sure? (y/n) ')
if confirm == 'y':
    from itertools import combinations_with_replacement as icombinations
    from tqdm.notebook import tqdm

    D = np.empty(ori.shape + ori.shape)
    n_pairs = ori.size * (ori.size + 1) // 2  # unordered pairs, including i == j

    # Iterate lazily rather than building a list of every index pair in memory
    for i, j in tqdm(icombinations(range(ori.size), 2), total=n_pairs):
        idx_1, idx_2 = np.unravel_index(i, ori.shape), np.unravel_index(j, ori.shape)
        o_1, o_2 = ori[idx_1], ori[idx_2]
        misori = D6.outer(~o_1).outer(o_2).outer(D6)  # shape (12, 12): left and right operators
        d = misori.angle.data.min()
        # D is symmetric, so fill both entries from one evaluation
        D[idx_1[0], idx_1[1], idx_2[0], idx_2[1]] = d
        D[idx_2[0], idx_2[1], idx_1[0], idx_1[1]] = d

**Option 4: Here is one we made earlier...**

In [None]:
filepath = './data/ori-distance((100, 200)).npy'
D = np.load(filepath)
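A precomputed D only matches this `ori` subset if its shape is `ori.shape + ori.shape`. A small, hypothetical helper that checks this, along with symmetry, a zero diagonal and finite values, before clustering. It is demonstrated on a synthetic 2 x 3 grid round-tripped through `np.save`/`np.load`.

```python
import os
import tempfile
import numpy as np

def check_distance_matrix(D, grid_shape, atol=1e-6):
    """Hypothetical helper: sanity-check D and return it flattened to (N, N)."""
    expected = tuple(grid_shape) * 2
    if D.shape != expected:
        raise ValueError(f'D has shape {D.shape}, expected {expected}')
    n = int(np.prod(grid_shape))
    D_flat = D.reshape(n, n)
    if not np.all(np.isfinite(D_flat)):
        raise ValueError('D contains non-finite entries (e.g. unfilled np.inf)')
    if not np.allclose(D_flat, D_flat.T, atol=atol):
        raise ValueError('D is not symmetric')
    if not np.allclose(np.diag(D_flat), 0.0, atol=atol):
        raise ValueError('D has a non-zero diagonal')
    return D_flat

# Round trip a synthetic matrix for a 2 x 3 grid through a .npy file
rng = np.random.default_rng(3)
A = rng.random((6, 6))
A = (A + A.T) / 2          # symmetric
np.fill_diagonal(A, 0.0)   # zero self-distance
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'D.npy')
    np.save(path, A.reshape(2, 3, 2, 3))
    D_loaded = np.load(path)

D_flat = check_distance_matrix(D_loaded, (2, 3))
# In the notebook this would be, e.g., check_distance_matrix(D, ori.shape)
```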

<h2> <center> Clustering </center> </h2>

In [None]:
# eps is in the units of D (radians): 0.1 rad is about 5.7 degrees
dbscan = DBSCAN(eps=0.1, min_samples=40, metric='precomputed').fit(D.reshape(ori.size, ori.size))
print('Labels:', np.unique(dbscan.labels_))
labels = dbscan.labels_.reshape(ori.shape)
n_clusters = len(np.unique(dbscan.labels_)) - 1
print('Number of clusters:', n_clusters)
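With `metric='precomputed'`, `eps` is compared directly against the entries of D, and in scikit-learn `min_samples` counts the neighbours within `eps` including the point itself. The following is a minimal sketch of the algorithm on a precomputed matrix, for illustration only: it is not scikit-learn's implementation, and a border point within reach of two clusters goes to whichever is found first.

```python
import numpy as np
from collections import deque

def dbscan_precomputed(D, eps, min_samples):
    """Minimal DBSCAN on a precomputed (N, N) distance matrix; noise is labelled -1."""
    n = D.shape[0]
    neighbours = [np.flatnonzero(D[i] <= eps) for i in range(n)]   # includes i itself
    is_core = np.array([len(nb) >= min_samples for nb in neighbours])
    labels = np.full(n, -1)
    cluster = 0
    for i in range(n):
        if labels[i] != -1 or not is_core[i]:
            continue
        # Grow a new cluster outwards from this unassigned core point
        labels[i] = cluster
        queue = deque([i])
        while queue:
            p = queue.popleft()
            if not is_core[p]:
                continue  # border points join the cluster but do not extend it
            for q in neighbours[p]:
                if labels[q] == -1:
                    labels[q] = cluster
                    queue.append(q)
        cluster += 1
    return labels

# Toy 1-D example: two tight groups of four points and one outlier
x = np.array([0.0, 0.01, 0.02, 0.03, 1.0, 1.01, 1.02, 1.03, 5.0])
D_toy = np.abs(x[:, None] - x[None, :])
labels_toy = dbscan_precomputed(D_toy, eps=0.05, min_samples=3)
print(labels_toy.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, -1]
```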

In [None]:
cluster_means = Orientation.stack([ori[labels == label].mean() for label in np.unique(dbscan.labels_)[1:]]).flatten()
cluster_means = cluster_means.set_symmetry(D6)
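`mean()` is used here as a black box. One common way to average unit quaternions, which may or may not match orix's implementation, is to take the eigenvector with the largest eigenvalue of M = Σ q qᵀ. Because q qᵀ = (−q)(−q)ᵀ, it is insensitive to the q/−q ambiguity. It does not resolve crystal symmetry, so it assumes the cluster members are already expressed as nearby symmetry variants.

```python
import numpy as np

def quaternion_mean(q):
    """Average scalar-first unit quaternions of shape (n, 4).

    Returns the principal eigenvector of M = sum_i q_i q_i^T, which is unchanged
    if any q_i is replaced by -q_i (the same rotation).
    """
    M = q.T @ q                          # (4, 4) symmetric matrix
    _, eigvecs = np.linalg.eigh(M)       # eigenvalues in ascending order
    mean = eigvecs[:, -1]                # eigenvector of the largest eigenvalue
    return mean if mean[0] >= 0 else -mean  # pick the sign with w >= 0

# +10 and -10 degrees about z (the second stored with flipped sign) average to the identity
half = np.radians(10) / 2
q = np.array([[np.cos(half), 0, 0, np.sin(half)],
              [-np.cos(half), 0, 0, np.sin(half)]])
print(quaternion_mean(q))
```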

In [None]:
cluster_means.axis

In [None]:
# Recenter all orientations on the mean of cluster 0 (the matrix) and recompute the cluster means
ori_recentered = (~cluster_means[0]) * ori
ori_recentered = ori_recentered.set_symmetry(D6)
cluster_means_recentered = Orientation.stack([ori_recentered[labels == label].mean() for label in np.unique(dbscan.labels_)[1:]]).flatten()
cluster_means_axangle = AxAngle.from_rotation(cluster_means_recentered)
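Left-multiplying every orientation by the inverse of the matrix mean, m⁻¹, rotates the whole set so that the matrix cluster sits at the identity, i.e. the origin of the axis-angle plot. For an eigenvector-based mean (an assumption here; orix's `mean()` may differ) this operation commutes with averaging, so the recentred matrix mean is exactly the identity. A numpy sketch on a synthetic cluster around a hypothetical centre:

```python
import numpy as np

def qmul(p, q):
    """Hamilton product of scalar-first quaternion arrays (..., 4), with broadcasting."""
    pw, px, py, pz = np.moveaxis(p, -1, 0)
    qw, qx, qy, qz = np.moveaxis(q, -1, 0)
    return np.stack([
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    ], axis=-1)

def conj(q):
    """Quaternion conjugate (inverse for unit quaternions)."""
    return q * np.array([1.0, -1.0, -1.0, -1.0])

def quaternion_mean(q):
    """Principal eigenvector of sum(q q^T), with the sign chosen so that w >= 0."""
    _, eigvecs = np.linalg.eigh(q.T @ q)
    mean = eigvecs[:, -1]
    return mean if mean[0] >= 0 else -mean

def small_rotations(rng, n, max_angle):
    """Random rotations by up to max_angle (radians) about random axes."""
    axes = rng.normal(size=(n, 3))
    axes /= np.linalg.norm(axes, axis=1, keepdims=True)
    half = 0.5 * rng.uniform(0, max_angle, size=(n, 1))
    return np.hstack([np.cos(half), np.sin(half) * axes])

rng = np.random.default_rng(4)
base = np.array([np.cos(0.6), 0.0, np.sin(0.6), 0.0])          # hypothetical cluster centre
cluster = qmul(base, small_rotations(rng, 50, np.radians(5)))   # synthetic "matrix" cluster
m = quaternion_mean(cluster)
recentred = qmul(conj(m), cluster)   # analogue of (~cluster_means[0]) * ori
print(quaternion_mean(recentred))    # close to [1, 0, 0, 0]
```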

In [None]:
cluster_means_recentered.axis

## Plotting

In [None]:
# Get label colours; bg_label=-1 draws DBSCAN noise as the (black) background
colors = [to_rgb('C{}'.format(i)) for i in range(10)]
labels_rgb = label2rgb(labels, colors=colors, bg_label=-1)

# Create map and lines pointing to cluster means
mapping = labels_rgb
collection = Line3DCollection([((0, 0, 0), tuple(cm)) for cm in cluster_means_axangle.data], colors=colors)
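The scatter points and the lines from the origin are drawn in axis-angle space, where a rotation by angle ω about the unit axis n is the vector ω n. A numpy sketch of that conversion from scalar-first unit quaternions; sign conventions may differ from orix's `AxAngle.from_rotation`.

```python
import numpy as np

def quaternion_to_axangle(q):
    """Map scalar-first unit quaternions (..., 4) to axis-angle vectors (..., 3)."""
    q = np.where(q[..., :1] < 0, -q, q)                # choose w >= 0, so the angle is in [0, pi]
    angle = 2 * np.arccos(np.clip(q[..., 0], 0.0, 1.0))
    sin_half = np.linalg.norm(q[..., 1:], axis=-1)      # equals sin(angle / 2)
    # Avoid 0/0 for the identity, whose axis-angle vector is the zero vector
    scale = np.divide(angle, sin_half, out=np.zeros_like(angle), where=sin_half > 1e-12)
    return q[..., 1:] * scale[..., None]

q90 = np.array([np.cos(np.pi / 4), 0, 0, np.sin(np.pi / 4)])   # 90 degrees about z
v = quaternion_to_axangle(np.array([q90, [1.0, 0, 0, 0]]))
# first row: (0, 0, pi/2); second row (identity): (0, 0, 0)
```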

In [None]:
# Main plot
fig = plt.figure(figsize=(3.484252, 3.484252))
gridspec = plt.GridSpec(1, 1, left=0, right=1, bottom=0, top=1, hspace=0.05)

ax_ori = fig.add_subplot(gridspec[0], projection='axangle', proj_type='ortho')
ax_ori.scatter(ori_recentered, c=labels_rgb.reshape(-1, 3), s=1)
ax_ori.plot_wireframe(fundamental_region, color='black', linewidth=0.5, alpha=0.1, rcount=181, ccount=361)
ax_ori.add_collection3d(collection)

ax_ori.set_axis_off()
ax_ori.set_xlim(-1, 1)
ax_ori.set_ylim(-1, 1)
ax_ori.set_zlim(-1, 1)
ax_ori.view_init(90, -30)

# Legend
handles = [
    Line2D(
        [0], [0], 
        marker='o', color='none', 
        label=i+1, 
        markerfacecolor=color, markersize=5
    ) for i, color in enumerate(colors[:n_clusters])
]

ax_ori.legend(handles=handles, loc='lower right', ncol=2, numpoints=1, labelspacing=0.15, columnspacing=0.15, handletextpad=0.05)

In [None]:
plt.close('all')
fig = plt.figure(figsize=(3.484252 * 2, 1.5 * 2))
gridspec = plt.GridSpec(1, 1, left=0, right=1, bottom=0, top=1, hspace=0.05)

ax_ori = fig.add_subplot(gridspec[0], projection='axangle', proj_type='ortho', aspect='equal')
ax_ori.scatter(ori_recentered, c=labels_rgb.reshape(-1, 3), s=1)
ax_ori.plot_wireframe(fundamental_region, color='black', linewidth=0.5, alpha=0.1, rcount=181, ccount=361)
# ax_ori.add_collection3d(collection)

ax_ori.set_axis_off()
ax_ori.set_xlim(-1, 1)
ax_ori.set_ylim(-1, 1)
ax_ori.set_zlim(-1, 1)
ax_ori.view_init(0, -30)

In [None]:
plt.close('all')
map_ax = plt.axes()
map_ax.imshow(mapping)

map_ax.set_xticks([])
map_ax.set_yticks([])