# Segmentation using pretrained G_MD.tar

This notebook runs a pretrained AtomAI model archive (`G_MD.tar`) on an input STEM image and extracts atomic coordinates.

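Colab notes: either place `G_MD.tar` in your Google Drive (`MyDrive/G_MD.tar`) or upload it when prompted. Install the requirements in the first cell, then run the notebook top to bottom. If the asyncroscopy simulation server (`localhost:9000`) is not reachable, a synthetic perfect-crystal image is generated instead. Outputs (segmentation mask, figure and, when available, a coordinates CSV) are written to `notebooks/output`.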

In [None]:
# Install required packages (runs quickly if already present)
# AtomAI is installed from GitHub; matplotlib/pillow for plotting.
!pip install --quiet git+https://github.com/pycroscopy/atomai
!pip install --quiet matplotlib pillow
print('Installed/ensured atomai, matplotlib, pillow')

In [None]:
# Basic imports
import os
import tempfile
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from PIL import Image
import torch

# Directory for all artifacts written by this notebook (mask .npy, figure .png,
# coordinates .csv). Relative path: resolved against the kernel's working directory.
OUT_DIR = 'notebooks/output'
os.makedirs(OUT_DIR, exist_ok=True)
print('OUT_DIR =', OUT_DIR)

# Seed the global RNGs so a Restart & Run All reproduces the same results --
# in particular any randomly initialised network weights (the untrained
# Segmentor used when the pretrained weights cannot be loaded).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Try to get a simulated STEM image via NotebookClient (same approach as Aberrations.ipynb).
# If the servers are not running, or return no image, fall back to a synthetic generator.


def generate_perfect_crystal(size=512, period_x=16, period_y=16, seed=0, noise_std=0.05):
    """Return a noisy synthetic square-lattice image as a float32 array in [0, 1].

    The lattice is the sum of two orthogonal cosine gratings with the given
    periods (in pixels), min-max scaled to [0, 1], plus Gaussian noise drawn
    from a generator seeded with ``seed``, then clipped back to [0, 1].

    Parameters
    ----------
    size : int
        Side length of the square output image, in pixels.
    period_x, period_y : float
        Lattice periods along columns (x) and rows (y), in pixels.
    seed : int
        Seed for the noise generator; equal arguments give identical images.
    noise_std : float
        Standard deviation of the additive Gaussian noise.

    Returns
    -------
    numpy.ndarray
        ``(size, size)`` float32 image.
    """
    rng = np.random.default_rng(seed)
    axis = np.arange(size)
    X, Y = np.meshgrid(axis, axis)
    image = 0.5 * (np.cos(2 * np.pi * X / period_x) + 1)
    image += 0.5 * (np.cos(2 * np.pi * Y / period_y) + 1)
    span = image.max() - image.min()
    # A constant lattice (e.g. size=1) cannot be min-max scaled without dividing by zero.
    if span > 0:
        image = (image - image.min()) / span
    noise = rng.normal(0, noise_std, image.shape)
    image = np.clip(image + noise, 0, 1)
    return image.astype(np.float32)


img = None
try:
    from asyncroscopy.clients.notebook_client import NotebookClient
    tem = NotebookClient.connect(host='localhost', port=9000)
    image_args = {'scanning_detector': 'HAADF', 'size': 512, 'dwell_time': 10e-6}
    img = tem.send_command('AS', 'get_scanned_image', image_args)
except Exception as e:
    print('Could not get image from servers - falling back to synthetic generator. Error:', e)
else:
    if img is None:
        print('Server returned no image - falling back to synthetic generator.')
    else:
        print('Obtained image from AS_server_SimAtomRes via NotebookClient')

if img is None:
    img = generate_perfect_crystal()

# Coerce whatever the acquisition step produced into a 2-D float32 image
# scaled to [0, 1], so the display and the model see a consistent input.
if isinstance(img, Image.Image):
    img = np.array(img.convert('F'), dtype=np.float32)
img = np.asarray(img, dtype=np.float32)
# Drop singleton axes such as (1, H, W), (H, W, 1) or (1, H, W, 1).
if img.ndim > 2:
    img = np.squeeze(img)
# Colour input (H, W, 3|4): average the RGB channels into a grey image.
if img.ndim == 3 and img.shape[-1] in (3, 4):
    img = img[..., :3].mean(axis=-1)
# Fail loudly here rather than deep inside the model/plotting cells.
if img.ndim != 2:
    raise ValueError(f'Expected a 2-D image, got an array of shape {img.shape}')
img_min, img_max = float(img.min()), float(img.max())
if img_max > img_min:  # a constant image is left unscaled (would divide by zero)
    img = (img - img_min) / (img_max - img_min)

fig_input, ax_input = plt.subplots(figsize=(5, 5))
ax_input.imshow(img, cmap='gray', origin='lower')
ax_input.set_title('Input Atomic-resolution Image')
ax_input.axis('off')
plt.show()

# keep 'img' variable for downstream cells

In [None]:
# Model archive location options for Colab:
# 1) Mount Google Drive and place G_MD.tar in MyDrive, or
# 2) Upload the file manually using the file picker further down in this cell.
# Outside Colab, set model_tar_path to a local copy of the archive instead.
import shutil

# Try Google Drive first (optional). If you are NOT using Drive, set use_drive = False.
use_drive = True
model_tar_path = '/content/G_MD.tar'
if use_drive:
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        drive_path = '/content/drive/MyDrive/G_MD.tar'
        if os.path.exists(drive_path):
            model_tar_path = drive_path
            print('Found model in Google Drive at', drive_path)
        else:
            print('Did not find model at', drive_path, '\nYou can upload it with the file picker that follows or set use_drive=False to skip Drive mounting.')
    except Exception as e:
        print('Drive mount failed or not in Colab environment:', e)

# Upload helper (runs in Colab) - will write to /content/G_MD.tar
ARCHIVE_SUFFIXES = ('.tar', '.tar.gz', '.tgz', '.zip')
if not os.path.exists(model_tar_path):
    try:
        from google.colab import files
        print('Please upload G_MD.tar via the file picker...')
        uploaded = files.upload()
        for name in uploaded:
            if name.endswith(ARCHIVE_SUFFIXES):
                # shutil.move (unlike os.rename) also works across filesystems.
                shutil.move(name, model_tar_path)
                print('Saved uploaded file to', model_tar_path)
                break
        else:
            print('None of the uploaded files is a .tar/.tar.gz/.tgz/.zip archive:', list(uploaded))
    except Exception as e:
        print('Upload skipped or not running in Colab filepicker:', e)

model_tar_path = os.path.expanduser(model_tar_path)
print('Using model archive at', model_tar_path)
if not os.path.exists(model_tar_path):
    print('WARNING: no model archive found at', model_tar_path,
          '- model loading will fall back to an untrained Segmentor.')

In [None]:
# Load atomai and attempt to load the pretrained model
import atomai as aai
import tarfile
import zipfile

# Number of output classes assumed when rebuilding the network from a bare
# state_dict. NOTE(review): confirm this matches how G_MD was trained.
NB_CLASSES = 3


def _extract_archive(archive_path, dest_dir):
    """Extract ``archive_path`` (tar, optionally compressed, or zip) into ``dest_dir``.

    Returns True on success; on failure prints both the tar and the zip error
    and returns False.
    """
    try:
        with tarfile.open(archive_path, 'r') as tar:
            # The 'data' extraction filter rejects absolute paths, '..' members and
            # links escaping dest_dir; it only exists where tarfile.data_filter does.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=dest_dir, filter='data')
            else:
                tar.extractall(path=dest_dir)
        return True
    except Exception as tar_err:
        tar_error = tar_err  # `tar_err` is unbound once the except clause ends
    try:
        with zipfile.ZipFile(archive_path, 'r') as zf:
            zf.extractall(dest_dir)
        return True
    except Exception as zip_err:
        print('Failed to extract archive (tar error: %s; zip error: %s)' % (tar_error, zip_err))
        return False


def _find_weights_file(root_dir):
    """Return the first *.pt / *.pth file found under ``root_dir`` (os.walk order), or None."""
    for dirpath, _, filenames in os.walk(root_dir):
        for fname in filenames:
            if fname.endswith(('.pt', '.pth')):
                return os.path.join(dirpath, fname)
    return None


def _segmentor_from_weights(weights_path):
    """Build an atomai Segmentor and load the state_dict stored in ``weights_path``.

    Accepts either a bare state_dict or a dict holding it under 'state_dict'.
    Exceptions from torch.load / load_state_dict propagate to the caller.
    """
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code -- only load archives from a trusted source.
    state = torch.load(weights_path, map_location='cpu')
    if isinstance(state, dict) and 'state_dict' in state:
        state = state['state_dict']
    segmentor = aai.models.Segmentor(nb_classes=NB_CLASSES)
    segmentor.net.load_state_dict(state)
    return segmentor


model = None
# Prefer letting atomai load the archive directly (it handles several formats)
if os.path.exists(model_tar_path):
    try:
        print('Attempting to load model archive with atomai.models.load_model(...)')
        model = aai.models.load_model(model_tar_path)
        print('Loaded model from archive using atomai.models.load_model')
    except Exception as e:
        print('atomai.models.load_model failed, will try to extract weights and load manually:', e)
        try:
            # Extracted files are only needed until torch.load returns, so use a
            # self-cleaning temporary directory instead of leaking one per run.
            with tempfile.TemporaryDirectory(prefix='G_MD_') as tmpd:
                model_file = _find_weights_file(tmpd) if _extract_archive(model_tar_path, tmpd) else None
                if model_file is None:
                    print('No .pt/.pth weights file found in the archive')
                else:
                    try:
                        # Assign `model` only after the weights actually loaded, so a
                        # failed load reaches the explicit untrained fallback below.
                        model = _segmentor_from_weights(model_file)
                        print('Loaded weights into a Segmentor instance from', model_file)
                    except Exception as e3:
                        print('Failed to load state_dict from', model_file, e3)
        except Exception as e_extract:
            print('Archive fallback/load failed:', e_extract)

if model is None:
    print('Falling back to an untrained Segmentor (results may not be meaningful)')
    model = aai.models.Segmentor(nb_classes=NB_CLASSES)

# model is ready (may be untrained if loading failed)

In [None]:
# Prepare tensor and run prediction (robust to different nn_output shapes)
X = torch.from_numpy(img[None, None, :, :]).float()
coords = None
nn_output = None
try:
    # prefer atom_find to get coordinates and segmentation map
    nn_output, coords = model.predict(X, method='atom_find')
    print('atom_find prediction returned types:', type(nn_output), type(coords))
except Exception as e:
    print('atom_find failed, trying generic predict:', e)
    try:
        result = model.predict(X)
        # A predictor may return (nn_output, coordinates) rather than the raw
        # output alone; unpack so downstream cells receive an array, not a tuple.
        if isinstance(result, tuple):
            nn_output = result[0] if result else None
            coords = result[1] if len(result) > 1 else None
        else:
            nn_output = result
    except Exception as e2:
        print('predict failed, will fallback to simple threshold segmentation:', e2)

# Helper to coerce nn_output into a 2D integer segmentation mask
def nn_output_to_mask(nn_out, ref_img=None):
    """Coerce a network output of assorted layouts into a 2-D int32 mask.

    Parameters
    ----------
    nn_out : array-like or None
        Network output. Supported layouts: (H, W), (C, H, W), (H, W, C),
        (N, C, H, W) and (N, H, W, C); for 4-D input only the first item of
        the batch is used. Multi-channel outputs are reduced with argmax over
        the channel axis; single-channel outputs are cast to int.
    ref_img : numpy.ndarray, optional
        Reference 2-D image, defaulting to the notebook's global ``img``. Its
        (H, W) shape tells the channel axis apart from the spatial axes, and
        it supplies the threshold fallback.

    Returns
    -------
    numpy.ndarray
        int32 mask. When ``nn_out`` is None or its layout is not recognised,
        ``ref_img > ref_img.mean()`` is returned instead.
    """
    if ref_img is None:
        ref_img = img
    ref = np.asarray(ref_img)

    def _threshold_mask():
        return (ref > ref.mean()).astype(np.int32)

    if nn_out is None:
        return _threshold_mask()
    a = np.asarray(nn_out)
    if a.ndim == 4:
        a = a[0]  # first (usually only) image of the batch
    if a.ndim == 2:
        return a.astype(np.int32)
    if a.ndim == 3:
        hw = ref.shape[:2]
        if a.shape[:2] == hw:
            ch_axis = 2  # (H, W, C)
        elif a.shape[1:] == hw:
            ch_axis = 0  # (C, H, W)
        else:
            # Spatial size differs from ref: assume the shorter end axis holds the channels.
            ch_axis = 2 if a.shape[2] <= a.shape[0] else 0
        if a.shape[ch_axis] == 1:
            return np.take(a, 0, axis=ch_axis).astype(np.int32)
        return np.argmax(a, axis=ch_axis).astype(np.int32)
    # Unrecognised rank: fall back to intensity thresholding.
    return _threshold_mask()

segmented = nn_output_to_mask(nn_output)
# The visualisation step overlays this mask on `img`, so it must be pixel-aligned
# with the input; otherwise fall back to a simple intensity-threshold mask.
if segmented.shape != img.shape:
    print('Mask shape', segmented.shape, 'does not match image shape', img.shape,
          '- falling back to threshold segmentation')
    segmented = (img > img.mean()).astype(np.int32)
print('Segmented mask shape:', segmented.shape)

In [None]:
# Visualize and save outputs.
#
# NOTE(review): atomai's atom_find coordinates appear to be (row, col, class) --
# the order produced by scipy.ndimage.center_of_mass -- confirm against the
# installed atomai version. With imshow(origin='lower'), pixel (row, col) is
# drawn at x=col, y=row, so x/y are taken from these columns below and the CSV
# uses the same convention as the plot.
COORD_ROW, COORD_COL, COORD_CLASS = 0, 1, 2


def _coords_to_array(raw_coords):
    """Return the coordinates of the first image as a NumPy array.

    Dict results keyed by image index use entry 0; anything else is converted directly.
    """
    if isinstance(raw_coords, dict) and 0 in raw_coords:
        raw_coords = raw_coords[0]
    return np.asarray(raw_coords)


def _split_coords(coords_arr):
    """Split an (N, >=2) coordinate array into plot-space x, y and integer classes.

    Rows without a class column are assigned class 0.
    """
    x = coords_arr[:, COORD_COL].astype(float)
    y = coords_arr[:, COORD_ROW].astype(float)
    if coords_arr.shape[1] > COORD_CLASS:
        classes = coords_arr[:, COORD_CLASS].astype(int)
    else:
        classes = np.zeros(len(coords_arr), dtype=int)
    return x, y, classes


fig, ax = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: segmentation mask overlaid on the input image.
ax[0].imshow(img, cmap='gray', origin='lower')
cmap_colors = ['k', 'red', 'blue', 'green', 'yellow']
n_colors = min(len(cmap_colors), max(2, int(segmented.max()) + 1))
cmap = ListedColormap(cmap_colors[:n_colors])
# Pin vmin/vmax so class k always gets colour k, even when some classes are
# absent from this mask (auto-scaling would otherwise shift the colours).
ax[0].imshow(segmented, cmap=cmap, alpha=0.5, origin='lower',
             vmin=0, vmax=n_colors - 1, interpolation='nearest')
ax[0].set_title('Segmentation Overlay')
ax[0].axis('off')

# Right panel: detected atom positions coloured by class.
ax[1].imshow(img, cmap='gray', origin='lower')
if coords is not None:
    try:
        coords_arr = _coords_to_array(coords)
        if coords_arr.ndim == 2 and coords_arr.shape[1] >= 2 and len(coords_arr):
            x, y, classes = _split_coords(coords_arr)
            class_colors = {0: 'cyan', 1: 'red', 2: 'blue', 3: 'green'}
            # Draw every class, including 0: rows without a class column are
            # labelled 0 above, so skipping it would hide every atom.
            for cl in np.unique(classes):
                sel = classes == cl
                ax[1].scatter(x[sel], y[sel], s=30, c=class_colors.get(cl, 'white'),
                              label=f'Class {cl}', edgecolor='yellow')
            ax[1].legend()
    except Exception as e:
        print('Failed to unpack coords for plotting:', e)

ax[1].set_title('Detected Atom Coordinates')
ax[1].axis('off')

# Save artifacts before plt.show(): some backends close the figure on show,
# so saving first guarantees the file matches what is displayed.
np.save(os.path.join(OUT_DIR, 'segmented_mask.npy'), segmented)
fig.savefig(os.path.join(OUT_DIR, 'segmentation_and_coords.png'), bbox_inches='tight', dpi=200)
print('Saved segmented_mask.npy and segmentation_and_coords.png to', OUT_DIR)
plt.show()

# If coords available, save CSV (x = column, y = row, same convention as the plot)
if coords is not None:
    try:
        import csv
        coords_arr = _coords_to_array(coords)
        if coords_arr.ndim != 2 or coords_arr.shape[1] < 2:
            raise ValueError(f'unexpected coordinate array shape {coords_arr.shape}')
        x, y, classes = _split_coords(coords_arr)
        csv_path = os.path.join(OUT_DIR, 'coords.csv')
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['x', 'y', 'class'])
            writer.writerows(zip(x.tolist(), y.tolist(), classes.tolist()))
        print('Saved coordinates CSV to', csv_path)
    except Exception as e:
        print('Failed to save coords CSV:', e)