# In [None]:
import torch
import timm
from torchvision import transforms
from cryocat import cryomap
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.backends.backend_pdf import PdfPages  # Import to handle saving multiple figures to a PDF

INPUT_TS = 'input_TS.mrc'
CLEANED_TS = 'cleaned_TS.mrc'
ANGLE_START = -50  # tilt angle (degrees) assigned to the first slice
ANGLE_STEP = 2  # angle increment (degrees) between consecutive slices
PDF_OUTPUT = 'output_visualization.pdf'
MODEL = 'models/swin_tiny_fine_tuned.pth'

# Set device to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The checkpoint filename decides which timm Swin variant to instantiate;
# it must match the architecture the state_dict was saved from.
if 'swin_tiny' in MODEL:
    model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=2)
elif 'swin_large' in MODEL:
    model = timm.create_model('swin_large_patch4_window7_224', pretrained=False, num_classes=2)
else:
    raise ValueError("MODEL file must contain 'swin_tiny' or 'swin_large'")

# Load the model's state_dict. map_location remaps tensors saved on a GPU onto
# the selected device, so the script also runs on CPU-only machines.
model.load_state_dict(torch.load(MODEL, map_location=device))
model = model.to(device)
model.eval()  # inference mode for dropout/normalization layers

# Preprocessing applied to every image before it is fed to the classifier.
# It is meant to mirror the transforms used when evaluating the model during
# training. NOTE(review): there is no Normalize step; confirm training used none.
image_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # the "_224" Swin variants take 224x224 input
        transforms.ToTensor(),  # PIL image -> float CxHxW tensor scaled to [0, 1]
    ]
)

def evaluate_single_image(image_input, index, class_0_info, class_1_info):
    """Classify one image with the module-level ``model`` and record the result.

    Args:
        image_input: Path to an image file (``str`` or path-like) or an
            already-loaded ``PIL.Image.Image``. Both are converted to RGB
            before preprocessing.
        index: Identifier stored with the probability; the script passes the
            tilt-slice index.
        class_0_info: List that receives ``(index, probability_of_class_0)``
            when the predicted class is 0.
        class_1_info: List that receives ``(index, probability_of_class_1)``
            when the predicted class is not 0.

    Returns:
        The predicted class index (argmax of the softmax output) as ``int``.

    Raises:
        TypeError: If ``image_input`` is neither a path nor a PIL image.
    """
    import os  # local import: only needed for the path-like check

    # Load and preprocess the image
    if isinstance(image_input, (str, os.PathLike)):
        # Context manager closes the underlying file once the pixels are copied.
        with Image.open(image_input) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input.convert("RGB")
    else:
        raise TypeError(
            f"image_input must be a path or PIL.Image.Image, got {type(image_input).__name__}"
        )
    batch = image_transforms(image).unsqueeze(0).to(device)  # add batch dimension

    # Forward pass through the model (no autograd bookkeeping needed)
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]

    predicted_class = int(np.argmax(probabilities))

    if predicted_class == 0:
        class_0_info.append((index, probabilities[0]))
    else:
        class_1_info.append((index, probabilities[1]))

    return predicted_class


def _to_uint8_image(tilt_slice):
    """Return a downscaled 8-bit greyscale PIL image of one tilt slice.

    The slice is rescaled with ``cryomap.scale(..., 0.0625)`` and then
    min-max stretched to 0-255. A constant slice (max == min) would divide by
    zero in the stretch, so it becomes an all-zero image instead.
    """
    scaled = np.asarray(cryomap.scale(tilt_slice, 0.0625))
    lo, hi = scaled.min(), scaled.max()
    if hi > lo:
        stretched = (scaled - lo) * (1 / (hi - lo) * 255)
    else:
        stretched = np.zeros(scaled.shape)
    return Image.fromarray(stretched.astype('uint8'))


def _probability_colors(prob, n_steps=100):
    """Return ``n_steps`` colours for a 0-1 probability bar.

    Values below ``prob`` are red and the rest are black. Rounding once
    (instead of truncating both parts) keeps the list exactly ``n_steps`` long,
    so the red/black boundary sits at ``prob`` on the bar.
    """
    n_red = min(max(int(round(prob * n_steps)), 0), n_steps)
    return ['red'] * n_red + ['black'] * (n_steps - n_red)


# Initialize variables
mrc = cryomap.read(INPUT_TS)

tomo3d = []  # slices kept for the cleaned tilt series
class_0_info = []  # (index, probability) of tilts predicted as class 0 (excluded)
class_1_info = []  # (index, probability) of tilts predicted as class 1 (kept)

# Plot setup
prev_text = False  # whether the previous excluded tilt was labelled

# Create a PDF to store all figures
with PdfPages(PDF_OUTPUT) as pdf:

    # First page: one line per tilt at its nominal angle.
    # Black solid = kept, red dashed = excluded.
    fig = plt.figure(figsize=(5, 5))
    plt.axis('off')

    for i in range(mrc.shape[2]):
        theta = np.radians(ANGLE_START + i * ANGLE_STEP)
        xs = [-np.cos(theta), np.cos(theta)]
        ys = [-np.sin(theta), np.sin(theta)]

        image_b16 = _to_uint8_image(mrc[:, :, i]).convert('RGB')
        if evaluate_single_image(image_b16, i, class_0_info, class_1_info):
            tomo3d.append(mrc[:, :, i])
            plt.plot(xs, ys, color='black', linewidth=1)
            prev_text = False
        else:
            plt.plot(xs, ys, color='red', linewidth=1, linestyle='--')
            # Label only every other consecutive excluded tilt so that
            # neighbouring numbers do not overlap.
            if not prev_text:
                plt.text(np.cos(theta) * 1.01, np.sin(theta) * 1.09, str(i + 1), fontsize=12, color='red')
            prev_text = not prev_text

    fig.text(0.5, 0.95, "Tilt Angle Visualization", ha='center', fontsize=14, weight='bold')
    pdf.savefig(fig)
    plt.close(fig)

    # Second page: thumbnails of the excluded tilts with a probability bar.
    num_images = len(class_0_info)
    if num_images == 0:
        # plt.subplots(0, cols) raises, and there is nothing to show.
        print("No tilt images were excluded; skipping the second PDF page.")
    else:
        cols = 3  # Number of columns in the grid layout
        rows = -(-num_images // cols)  # ceiling division
        # squeeze=False always returns a 2-D array of axes, so flatten() is safe.
        fig, axes = plt.subplots(rows, cols, figsize=(10, rows * 3), squeeze=False)
        axes = axes.flatten()
        fig.subplots_adjust(top=0.8, hspace=0.5, wspace=0.5)

        for ax, (index, prob) in zip(axes, class_0_info):
            ax.imshow(_to_uint8_image(mrc[:, :, index]), cmap='gray')
            ax.axis('off')

            # Probability scale bar next to each image
            cbar = fig.colorbar(
                plt.cm.ScalarMappable(
                    cmap=ListedColormap(_probability_colors(prob)),
                    norm=plt.Normalize(vmin=0, vmax=1),
                ),
                ax=ax, orientation='vertical', fraction=0.046, pad=0.04,
            )
            cbar.set_ticks([0, 0.5, 1])
            cbar.set_ticklabels([f'{int(t * 100)}%' for t in [0, 0.5, 1]])

            ax.set_title(f"Index: {index+1} | Prob: {prob:.2%}")

        # Remove the unused cells of the last grid row.
        for ax in axes[num_images:]:
            fig.delaxes(ax)

        fig.text(0.5, 0.9, "Excluded Tilt Images with Probability Scale Bar", ha='center', fontsize=14, weight='bold')
        pdf.savefig(fig)
        plt.close(fig)

# Stack the kept slices along axis 2 (the same axis they were taken from) and save.
if not tomo3d:
    raise RuntimeError(
        f"No tilt image in {INPUT_TS} was classified as usable; {CLEANED_TS} was not written."
    )
tomo3d = np.stack(tomo3d, axis=2)
cryomap.write(tomo3d, CLEANED_TS, data_type=np.single)

# NOTE(review): stray, mis-indented leftover line (it was an IndentationError);
# kept as a comment because it names an alternative checkpoint path.
# model.load_state_dict(torch.load('/home/ms/tomajtne/iciap/models/b16_oldAnnot_TSsplit/all/swin_tiny_patch4_window7_224_alltrain_50e.pth'))
