In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
def resize_images(image_filenames, target_size):
    """Open each image file and return a resized, in-memory copy of it.

    Args:
        image_filenames: Iterable of paths accepted by ``Image.open``.
        target_size: Size tuple passed straight to ``Image.resize`` (PIL's
            ``(width, height)`` order).

    Returns:
        List of resized images, in the same order as ``image_filenames``.
    """
    resized_images = []
    for filename in image_filenames:
        # Close each source file promptly: the resized copy does not need it, and a
        # later cell reads summed.jpg and then saves over the same path.
        with Image.open(filename) as img:
            # LANCZOS is the filter formerly aliased as ANTIALIAS; the alias was
            # removed in Pillow 10, so ANTIALIAS raises AttributeError there.
            resized_images.append(img.resize(target_size, Image.LANCZOS))
    return resized_images

def create_image_grid(image_list, max_images_per_row, mode='L', background=255):
    """Tile images row by row onto one canvas, centring each row horizontally.

    Args:
        image_list: Non-empty sequence of images, placed left-to-right then
            top-to-bottom.
        max_images_per_row: Maximum number of images per row; must be >= 1.
        mode: Mode of the output canvas passed to ``Image.new`` (default ``'L'``).
        background: Fill value for canvas pixels no image covers (default 255).

    Returns:
        The new canvas image containing the grid.

    Raises:
        ValueError: if ``image_list`` is empty or ``max_images_per_row`` < 1.
    """
    if max_images_per_row < 1:
        raise ValueError(f"max_images_per_row must be >= 1, got {max_images_per_row}")
    if len(image_list) == 0:
        raise ValueError("image_list must contain at least one image")

    rows = [image_list[start:start + max_images_per_row]
            for start in range(0, len(image_list), max_images_per_row)]
    row_widths = [sum(img.width for img in row) for row in rows]
    row_heights = [max(img.height for img in row) for row in rows]

    # Keep the original sizing (first image's size times the grid shape), so uniform
    # inputs give the same canvas as before, but grow the canvas if larger images
    # would otherwise be cropped (rows advance by their tallest image).
    grid_width = max(image_list[0].width * max_images_per_row, max(row_widths))
    grid_height = max(image_list[0].height * len(rows), sum(row_heights))

    grid = Image.new(mode, (grid_width, grid_height), background)

    y_offset = 0
    for row, row_width, row_height in zip(rows, row_widths, row_heights):
        x_offset = (grid_width - row_width) // 2  # centre a (possibly short) row
        for img in row:
            grid.paste(img, (x_offset, y_offset))
            x_offset += img.width
        y_offset += row_height

    return grid

def mrsa_angle(x, y):
    """Mean-removed spectral angle between ``x`` and ``y``, scaled to [0, 1].

    Each input is centred on its own mean (over all elements). The result is the
    angle between the centred vectors divided by pi: 0 for perfectly positively
    correlated inputs, 0.5 for uncorrelated, 1 for perfectly anti-correlated.

    Args:
        x, y: Arrays of equal shape, either 1-D vectors or ``(k, 1)`` column vectors
            (``find_lowest_mrsa_index`` passes ``W[:, idx:idx+1]`` columns).

    Returns:
        A scalar for 1-D inputs; a ``(1, 1)`` array for column-vector inputs.
        NaN when either input is constant, because its centred norm is zero.
    """
    xdiff = x - np.mean(x)
    ydiff = y - np.mean(y)
    cosine = (xdiff.T @ ydiff) / (np.linalg.norm(xdiff) * np.linalg.norm(ydiff))
    # Rounding can push the cosine of (anti)parallel vectors just outside [-1, 1],
    # where arccos returns NaN; clip it back into the domain.
    mrsa_value = (1 / np.pi) * np.arccos(np.clip(cosine, -1.0, 1.0))
    return mrsa_value

def find_lowest_mrsa_index(W, search_idxs, sumW):
    """Return the column index of ``W`` whose MRSA against ``sumW`` is smallest.

    Each candidate ``idx`` in ``search_idxs`` is scored as
    ``mrsa_angle(sumW, W[:, idx:idx+1])``. Only a strictly smaller score replaces
    the current best, so the earliest index wins ties and NaN scores never win.

    Returns:
        ``(best_idx, best_score)``, or ``(None, inf)`` when no candidate scores
        below infinity (e.g. ``search_idxs`` is empty).
    """
    best_idx, best_score = None, float('inf')
    for candidate in search_idxs:
        score = mrsa_angle(sumW, W[:, candidate:candidate + 1])
        if not score < best_score:
            continue
        best_idx, best_score = candidate, score
    return best_idx, best_score

In [None]:
# Input dataset and output locations for this experiment. The `{}` placeholders are
# str.format templates (presumably r / lambda / gamma / iterations — TODO confirm).
data_filepath, ini_filepath = (
    '../datasets/jasper_full.npz',
    '../saved_models/jasper_full/r{}_ini.npz',
)
save_filepath, image_filepath = (
    '../saved_models/jasper_full/r{}_l{}_g{}_it{}.npz',
    '../images/jasper_full/r{}_l{}_g{}_it{}.jpg',
)

In [None]:
# Load the data matrix stored under key 'X' as float64; m, n are its two dimensions.
# The context manager closes the underlying .npz archive once 'X' has been copied out.
with np.load(data_filepath) as data_archive:
    M = data_archive['X'].astype(np.float64)
m, n = M.shape

In [None]:
# Number of components: H is reshaped into `r` image planes further down.
r = 20
# NOTE(review): presumably the iteration count and the two regularisation weights of
# the saved run (matching the it/l/g filename fields) — TODO confirm.
iters = 1_000
_lam = _gamma = 1_000_000

In [None]:
# Components handled separately: excluded from `search_idxs` and tiled individually
# in the level-2 grid (presumably chosen by inspecting the component maps).
unique_idxs = [
    6,
    8,
    14,
]

In [None]:
# Each of the r rows of H is reshaped (column-major, order='F') into an image plane.
# NOTE(review): the scene is assumed to be 100 x 100 pixels — TODO confirm.
IMG_ROWS, IMG_COLS = 100, 100

# No earlier cell creates `H`, so on a fresh kernel this cell raised NameError.
# Load it from the saved model for these hyperparameters unless it already exists.
if 'H' not in globals():
    # NOTE(review): assumes the archive stores the matrix under key 'H' and that the
    # template fields are (r, lam, gamma, iters) — TODO confirm.
    with np.load(save_filepath.format(r, _lam, _gamma, iters)) as saved_model:
        H = saved_model['H']

H3d = H.reshape(r, IMG_ROWS, IMG_COLS, order='F')
sumH = H3d.sum(axis=0)

search_idxs = [x for x in range(r) if x not in unique_idxs]

In [None]:
# Save each component plane H3d[i] as its own grayscale JPEG with a colorbar.
# An explicit figure per image avoids drawing into a figure left open elsewhere.
for i in range(r):
    fig, ax = plt.subplots()
    im = ax.imshow(H3d[i, :, :], cmap='gray')
    fig.colorbar(im, ax=ax)
    # Hide tick marks and labels on both axes (the frame stays).
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    fig.savefig(f'../images/jasper_full/jasper_full_c{i}.jpg', bbox_inches='tight')
    plt.close(fig)

In [None]:
# Sum of the component planes listed in `search_idxs` (all except `unique_idxs`).
sumH2 = H3d[search_idxs].sum(axis=0)

fig, ax = plt.subplots()
im = ax.imshow(sumH2, cmap='gray')
fig.colorbar(im, ax=ax)
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)

# The level-2 grid cell reads nonunique.jpg, so write it here to keep a fresh-kernel
# run self-contained. The figure is left open so it still renders inline.
fig.savefig('../images/jasper_full/nonunique.jpg', bbox_inches='tight')


In [None]:
# Plot the sum over all r component planes and save it as summed.jpg.
# Use a fresh figure: the previous cell never closes its figure, so under a
# non-inline backend pyplot calls would add a second image and colorbar to it.
fig, ax = plt.subplots()
im = ax.imshow(sumH, cmap='gray')
fig.colorbar(im, ax=ax)
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)

fig.savefig('../images/jasper_full/summed.jpg', bbox_inches='tight')
plt.close(fig)

In [None]:
# Tile every component map not in `unique_idxs`, five per row, into one overview image.
max_images_per_row = 5
target_size = (500, 500)  # size each tile is resized to before tiling
image_filenames = [
    f'../images/jasper_full/jasper_full_c{idx}.jpg'
    for idx in range(r)
    if idx not in unique_idxs
]

resized_images = resize_images(image_filenames, target_size)
grid_image = create_image_grid(resized_images, max_images_per_row)
grid_image.show()
grid_image.save('../images/jasper_full/nonunique_grid.jpg')

In [None]:
# Level-2 grid: nonunique.jpg (the sumH2 plot, which must already exist on disk)
# followed by each unique component map, four per row.
max_images_per_row = 4
target_size = (500, 500)  # size each tile is resized to before tiling
image_filenames = (
    ['../images/jasper_full/nonunique.jpg']
    + [f'../images/jasper_full/jasper_full_c{idx}.jpg' for idx in unique_idxs]
)

resized_images = resize_images(image_filenames, target_size)
grid_image = create_image_grid(resized_images, max_images_per_row)
grid_image.show()
grid_image.save('../images/jasper_full/level_2_grid.jpg')

In [None]:
# Resize summed.jpg to 550x500 and save it back over the same path (in place).
resized_images = resize_images(
    ['../images/jasper_full/summed.jpg'],
    (550, 500),
)
resized_images[0].show()
resized_images[0].save('../images/jasper_full/summed.jpg')


In [None]:
# Cluster 1: a fixed set of component maps (presumably grouped by hand), two per row.
max_images_per_row = 2
target_size = (500, 500)  # size each tile is resized to before tiling
image_filenames = [
    f'../images/jasper_full/jasper_full_c{idx}.jpg'
    for idx in [0, 1, 4, 7, 11, 18]
]

resized_images = resize_images(image_filenames, target_size)
grid_image = create_image_grid(resized_images, max_images_per_row)
grid_image.show()
grid_image.save('../images/jasper_full/cluster_1.jpg')

In [None]:
# Cluster 2: a fixed set of component maps (presumably grouped by hand), three per row.
max_images_per_row = 3
target_size = (500, 500)  # size each tile is resized to before tiling
image_filenames = [
    f'../images/jasper_full/jasper_full_c{idx}.jpg'
    for idx in [2, 3, 5, 9, 10, 12, 13, 16]
]

resized_images = resize_images(image_filenames, target_size)
grid_image = create_image_grid(resized_images, max_images_per_row)
grid_image.show()
grid_image.save('../images/jasper_full/cluster_2.jpg')

In [None]:
# image_filenames = [f'../images/jasper_full/jasper_full_c{idx}.jpg' for idx in [11, 15]]
# target_size = (500, 500)  # Adjust the size as needed
# max_images_per_row = 1

# resized_images = resize_images(image_filenames, target_size)
# grid_image = create_image_grid(resized_images, max_images_per_row)
# grid_image.show()
# grid_image.save('../images/jasper_full/cluster_3.jpg')

In [None]:
# image_filenames = [f'../images/jasper_full/jasper_full_c{idx}.jpg' for idx in [15, 17, 19]]
# target_size = (500, 500)  # Adjust the size as needed
# max_images_per_row = 1

# resized_images = resize_images(image_filenames, target_size)
# grid_image = create_image_grid(resized_images, max_images_per_row)
# grid_image.show()
# grid_image.save('../images/jasper_full/cluster_3.jpg')

In [None]:
# i = 1
# search_idxs = [x for x in range(r) if x not in unique_idxs]

# while len(search_idxs) > 0:
#     curr_idx, curr_mrsa = find_lowest_mrsa_index(W, search_idxs, W[:, search_idxs].sum(axis=1, keepdims=True))
#     print(curr_idx, curr_mrsa)
#     plt.imshow(H3d[curr_idx], cmap='gray')
#     plt.colorbar()
#     a = plt.gca()
#     xax = a.axes.get_xaxis()
#     xax = xax.set_visible(False)
#     yax = a.axes.get_yaxis()
#     yax = yax.set_visible(False)

#     filename = f'../images/jasper_full/l{i}_ex_{curr_idx}.jpg'
#     plt.savefig(filename, bbox_inches='tight')
#     plt.close()

#     img = Image.open(filename)
#     img = img.resize((500, 500), Image.ANTIALIAS)
#     img.save(filename)        


#     search_idxs.remove(curr_idx)

#     plt.imshow(H3d[search_idxs].sum(axis=0), cmap='gray')
#     plt.colorbar()
#     a = plt.gca()
#     xax = a.axes.get_xaxis()
#     xax = xax.set_visible(False)
#     yax = a.axes.get_yaxis()
#     yax = yax.set_visible(False)

#     filename = f'../images/jasper_full/l{i}_summed.jpg'
#     plt.savefig(filename, bbox_inches='tight')
#     plt.close()

#     img = Image.open(filename)
#     img = img.resize((500, 500), Image.ANTIALIAS)
#     img.save(filename)   

#     i += 1