In [6]:
import numpy as np
import torch
import atomai as aai
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import tarfile
import os

# --- 1. PERFECT CRYSTAL IMAGE GENERATION (REPLACES IMAGE LOADING) ---

def generate_perfect_crystal(size=512, period_x=16, period_y=16, noise_std=0.05, seed=None):
    """Generate a synthetic image of a perfect 2D crystal lattice.

    The lattice is the sum of two raised cosines, one along each image axis,
    rescaled to [0, 1]. Gaussian noise is then added and the result is clipped
    back to [0, 1].

    Args:
        size: Side length of the square image, in pixels.
        period_x: Lattice period along the horizontal (column) axis, in pixels.
        period_y: Lattice period along the vertical (row) axis, in pixels.
        noise_std: Standard deviation of the additive Gaussian noise
            (0 disables the noise).
        seed: Optional seed. When given, the noise comes from a dedicated
            ``np.random.default_rng(seed)`` so the image is reproducible.
            When ``None``, NumPy's global random state is used, which is
            what the function did before this parameter existed.

    Returns:
        A ``(size, size)`` ``float32`` array with values in [0, 1].
    """
    cols, rows = np.meshgrid(np.arange(size), np.arange(size))

    # Simple 2D sinusoidal pattern standing in for atomic columns
    image = 0.5 * (np.cos(2 * np.pi * cols / period_x) + 1)
    image += 0.5 * (np.cos(2 * np.pi * rows / period_y) + 1)

    # Rescale to [0, 1]; a constant image (e.g. size=1) has zero span and
    # would otherwise turn into NaN through 0/0.
    span = image.max() - image.min()
    if span > 0:
        image = (image - image.min()) / span
    else:
        image = np.zeros_like(image)

    # Both the numpy.random module and a Generator expose normal(loc, scale, size)
    rng = np.random if seed is None else np.random.default_rng(seed)
    noise = rng.normal(0, noise_std, image.shape)
    image = np.clip(image + noise, 0, 1)

    return image.astype(np.float32)

# Build the synthetic lattice image with the generator's default settings
crystal_image = generate_perfect_crystal()

# Show it on a dedicated 6x6 inch figure, in grayscale and without axis ticks
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(crystal_image, cmap='gray')
ax.set_title('Synthetic Perfect Crystal Image (Input for Segmentation)')
ax.axis('off')
plt.show()

ModuleNotFoundError: No module named 'numpy'

In [None]:
# --- USER INPUT REQUIRED: Confirm the path to your file ---
TAR_FILE_PATH = 'G_MD.tar'
# Set the directory for extraction
EXTRACTION_DIR = './G_MD_extracted'
# File-name endings that mark a file as a model when scanning the extracted tree
MODEL_SUFFIXES = ('.pt', '.pth', 'model.h5')
# -------------------------


def extract_archive(archive_path, dest_dir):
    """Extract a tar archive into ``dest_dir`` while refusing unsafe members.

    Args:
        archive_path: Path of the tar archive to read.
        dest_dir: Directory that receives the extracted files.

    Raises:
        tarfile.TarError: If the archive cannot be read or contains a member
            that would be written outside ``dest_dir``.
        OSError: If the archive or the destination cannot be accessed.
    """
    with tarfile.open(archive_path, 'r') as tar:
        if hasattr(tarfile, 'data_filter'):
            # Python versions with extraction filters: let the standard
            # 'data' filter reject members that are unsafe to extract.
            tar.extractall(path=dest_dir, filter='data')
            return

        # Older Python: check every member's resolved target ourselves before
        # extracting anything. This only guards member paths; it does not
        # inspect link targets.
        dest_root = os.path.realpath(dest_dir)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise tarfile.TarError(f"Refusing to extract unsafe path: {member.name}")
        tar.extractall(path=dest_dir)


def find_model_file(search_dir, suffixes=MODEL_SUFFIXES):
    """Return the first file under ``search_dir`` whose name ends with a suffix.

    Directories and files are visited in sorted order so the same tree always
    yields the same result, independent of the file system's listing order.

    Args:
        search_dir: Root directory to scan recursively.
        suffixes: Iterable of file-name endings to accept.

    Returns:
        The path of the first match, or ``None`` when nothing matches.
    """
    endings = tuple(suffixes)
    for root, dirs, files in os.walk(search_dir):
        dirs.sort()  # in-place sort steers the order of the remaining walk
        for name in sorted(files):
            if name.endswith(endings):
                return os.path.join(root, name)
    return None


# Ensure extraction directory exists
os.makedirs(EXTRACTION_DIR, exist_ok=True)
CUSTOM_MODEL_PATH = None

# Check if the file exists
if not os.path.exists(TAR_FILE_PATH):
    print(f"Error: The file was not found at {TAR_FILE_PATH}. Please check the path.")

else:
    try:
        if tarfile.is_tarfile(TAR_FILE_PATH):
            # 1. Extract the .tar file
            extract_archive(TAR_FILE_PATH, EXTRACTION_DIR)
            print(f"Successfully extracted {TAR_FILE_PATH} to {EXTRACTION_DIR}")

            # 2. Automatically find the model file inside the extracted folder
            CUSTOM_MODEL_PATH = find_model_file(EXTRACTION_DIR)
            if CUSTOM_MODEL_PATH:
                print(f"Found model file: {CUSTOM_MODEL_PATH}")
            else:
                print(f"Error: Could not find a model file ({', '.join(MODEL_SUFFIXES)}) inside the archive.")
        else:
            # NOTE(review): a checkpoint can carry a .tar name without being a
            # tar archive (AtomAI's own saved models appear to do this;
            # confirm). Hand such a file straight to the loader instead of
            # giving up; a real mismatch is still reported by the load step.
            CUSTOM_MODEL_PATH = TAR_FILE_PATH
            print(f"{TAR_FILE_PATH} is not a tar archive; trying to load it directly as a model file.")

    except Exception as e:
        # Best effort: report the problem and leave CUSTOM_MODEL_PATH unset
        print(f"An error occurred during extraction: {e}")

# 3. Load the model
m_sem = None
if CUSTOM_MODEL_PATH and os.path.exists(CUSTOM_MODEL_PATH):
    try:
        # NOTE(review): confirm the installed AtomAI version's load_model
        # accepts `full_path`; if it does not, the error is reported below.
        m_sem = aai.models.load_model(CUSTOM_MODEL_PATH, full_path=True)
        print("Successfully loaded model with custom retrained weights.")
    except Exception as e:
        print(f"Error loading model with atomai: {e}")
        print("This might happen if the model file is not an AtomAI compatible PyTorch save.")

In [None]:
def class_map_from_output(nn_output, image_shape, threshold=0.5):
    """Collapse raw network output into a 2D integer class map for one image.

    A 4D input is treated as a batch and only its first entry is used. The
    class axis is found by matching the remaining shape against
    ``image_shape``: a trailing class axis (H, W, C) is tried first, then a
    leading one (C, H, W).

    Args:
        nn_output: Array-like network output with 2, 3 or 4 dimensions.
        image_shape: ``(height, width)`` of the image that was segmented.
        threshold: Cut-off used when there is a single output channel, where
            an argmax would always give 0. Assumes that channel holds
            probabilities in [0, 1]; TODO confirm for the loaded model.

    Returns:
        Integer array of shape ``image_shape`` with one class index per pixel.

    Raises:
        ValueError: If the output cannot be matched to ``image_shape``.
    """
    pred = np.asarray(nn_output)
    if pred.ndim == 4:
        pred = pred[0]  # first image of the batch
    if pred.ndim == 2:
        pred = pred[..., None]  # single channel without an explicit class axis
    if pred.ndim != 3:
        raise ValueError(f"Unsupported network output shape: {np.shape(nn_output)}")

    image_shape = tuple(image_shape)
    if pred.shape[:2] == image_shape:
        class_axis = -1
    elif pred.shape[1:] == image_shape:
        class_axis = 0
    else:
        raise ValueError(
            f"Network output shape {pred.shape} does not match image shape {image_shape}"
        )

    if pred.shape[class_axis] == 1:
        return (np.take(pred, 0, axis=class_axis) > threshold).astype(int)
    return np.argmax(pred, axis=class_axis)


if m_sem is not None:
    # Prepare the image tensor: (1, 1, H, W)
    X_tensor = torch.from_numpy(crystal_image[None, None]).float()

    # --- 1. Run Segmentation and Atom Localization ---
    # predict() is unpacked into the raw network output and a mapping from
    # image index to an (N atoms x 3) coordinate array.
    nn_output, coordinates = m_sem.predict(X_tensor, method='atom_find')

    # Extract the coordinates array for the single image
    coords_array = coordinates[0]

    # Separate the two position columns and the predicted class.
    # NOTE(review): this treats column 0 as x and column 1 as y. If AtomAI
    # reports positions in array-index order (row, col), the two are swapped
    # and the scatter below is transposed; confirm against AtomAI's docs.
    x_coords = coords_array[:, 0]
    y_coords = coords_array[:, 1]
    classes = coords_array[:, 2]

    print(f"\nSegmentation complete. Found {len(coords_array)} atomic columns.")

    # --- 2. Visualization ---

    # Define a default set of colors for the classes (e.g., 1, 2, 3...)
    class_colors = {
        1: 'red',
        2: 'blue',
        3: 'green',
        4: 'yellow'
    }

    # Per-pixel class map. The class axis has to be reduced explicitly:
    # indexing [0, 0] first and then taking argmax over axis 0 would leave
    # only a 1D vector, which imshow cannot draw.
    segmented_image = class_map_from_output(nn_output, crystal_image.shape)

    # Prepare colormap for the mask (class 0 is drawn black)
    colors = ['k'] + [class_colors.get(i, 'cyan') for i in range(1, int(segmented_image.max()) + 1)]
    cmap = ListedColormap(colors)

    fig, ax = plt.subplots(1, 2, figsize=(14, 7))

    # Left: Segmentation Mask
    ax[0].imshow(crystal_image, cmap='gray')
    # Pin the color range so class k always maps to colors[k], even when
    # class 0 does not occur in the mask.
    ax[0].imshow(
        segmented_image,
        cmap=cmap,
        alpha=0.5,
        vmin=0,
        vmax=len(colors) - 1,
        interpolation='nearest',
    )
    ax[0].set_title('Semantic Segmentation Mask')
    ax[0].axis('off')

    # Right: Coordinate Plot over Input Image
    ax[1].imshow(crystal_image, cmap='gray')

    # Overlay the found coordinates
    unique_classes = np.unique(classes).astype(int)
    for class_id in unique_classes:
        # NOTE(review): class 0 is skipped as background. If the model numbers
        # its atom classes from 0, this hides real atoms; confirm.
        if class_id == 0:
            continue

        mask = (classes == class_id)
        ax[1].scatter(
            x_coords[mask],
            y_coords[mask],
            s=40,
            c=class_colors.get(class_id, 'white'),  # Use defined color, or default to white
            marker='o',
            edgecolor='yellow',
            linewidth=1,
            label=f'Atom Class {class_id}'
        )

    ax[1].set_title('Extracted Coordinates for Defect Twin')
    ax[1].axis('off')
    # Only draw a legend when something was plotted; an empty legend just
    # triggers a warning.
    handles, _ = ax[1].get_legend_handles_labels()
    if handles:
        ax[1].legend()

    plt.tight_layout()
    plt.show()

    # --- Output for Defect Twin ---
    print("\n\n--- Final Output: Coordinates Array for Defect Twin ---")
    print(f"Coordinate Array (x, y, class) Shape: {coords_array.shape}")
    print("First 5 coordinates:")
    print(coords_array[:5])
    print("\nUse the 'coords_array' as the ideal input for your defect generation twin.")

else:
    print("Cannot proceed with segmentation: Model loading failed in the previous step. Please check Cell 2.")