# Use pretrained models for clustering

Tired of training models? Just download a pretrained one!

Pretrained models for many seismology tasks can be found, for instance, in [SeisBench](https://github.com/seisbench/seisbench), but let's do something different.

An interesting (and increasingly popular) way of clustering similar waveforms is to convert them into images and then apply computer vision models, which are highly capable and easy to download with pretrained weights. Converting waveforms to images is something we already know how to do: we compute spectrograms.

Computer vision models are typically trained to classify objects, not to look at spectrograms. We'll use a small trick to put their pattern-recognition abilities to work for clustering instead. But first, some imports:

In [None]:
from pathlib import Path
import numpy as np
import h5py
import obspy
import scipy.signal
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

## Download data

In [None]:
! wget https://storage.googleapis.com/norsar-ml-ws/events_classification_Zonly_TRAIN.h5

## Create spectrograms

Here we can again make use of the ObsPy library. For visualisation, we first plot the waveform of one event, and then its spectrogram together with the original waveform.

In [None]:
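# Read only the first 100 events to keep the example fast
with h5py.File('events_classification_Zonly_TRAIN.h5', "r") as fin:
    waveforms = fin['waveforms'][:100]  # shape: (events, samples, channels)
    event_types = fin['type'][:100]     # event labels for the same 100 events

event_index = 1       # which event to plot
sampling_freq = 100   # sampling frequency of the data (Hz)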

plt.figure(figsize=(10, 4))
time_axis = np.arange(len(waveforms[event_index, :, 0])) / sampling_freq  # Convert samples to time
plt.plot(time_axis, waveforms[event_index, :, 0], 'b-', linewidth=0.8)
plt.xlabel('Time (seconds)')
plt.ylabel('Amplitude')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
from obspy.imaging.spectrogram import spectrogram 

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), 
                               gridspec_kw={'height_ratios': [2, 1]}, 
                               sharex=True)

# Spectrogram
ax1.set_ylabel('Frequency (Hz)')
spectrogram(waveforms[event_index, :, 0], sampling_freq, axes=ax1, show=False)

# Waveform 
time_axis = np.arange(len(waveforms[event_index, :, 0])) / sampling_freq  # Convert samples to time
ax2.plot(time_axis, waveforms[event_index, :, 0], 'b-', linewidth=0.8)
ax2.set_xlabel('Time (seconds)')
ax2.set_ylabel('Amplitude')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
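
ObsPy's `spectrogram` draws a figure. If you want the spectrogram as a NumPy array instead (for example, to feed it to a model without going through PNG files), `scipy.signal.spectrogram` returns the frequencies, segment times and power values. Below is a minimal sketch on a synthetic signal (a 5 Hz burst in noise, not real data); the window length and overlap are arbitrary choices for illustration, not ObsPy's defaults.

```python
import numpy as np
import scipy.signal

fs = 100                      # sampling frequency in Hz, as used above
t = np.arange(60 * fs) / fs   # 60 s time axis (6000 samples)

# Synthetic test signal (not from the dataset): noise plus a 5 Hz burst from 20 s to 30 s
rng = np.random.default_rng(0)
signal = 0.1 * rng.standard_normal(t.size)
burst = (t >= 20) & (t < 30)
signal[burst] += np.sin(2 * np.pi * 5 * t[burst])

# 256-sample windows with ~90 % overlap (arbitrary choices for this sketch)
freqs, times, sxx = scipy.signal.spectrogram(signal, fs=fs, nperseg=256, noverlap=230)

# Convert power to decibels, since spectrogram images usually show log power
sxx_db = 10 * np.log10(sxx + 1e-12)

# Frequency with the most power, averaged over segments lying fully inside the burst
inside = (times > 21.5) & (times < 28.5)
peak_freq = freqs[np.argmax(sxx[:, inside].mean(axis=1))]
print(f"sxx shape: {sxx.shape}, peak frequency in burst: {peak_freq:.2f} Hz")
```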

## Write spectrograms to PNG files

Keras, which we'll use for the computer vision model, comes with convenience functions for reading image files, so let's first write our spectrograms to PNG files.

In [None]:
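def write_spectrogram(waveform, fs, filename):
    """Save the spectrogram of `waveform` (sampled at `fs` Hz) as a borderless PNG image."""
    fig = spectrogram(waveform, fs, show=False)
    for ax in fig.get_axes():
        ax.set_ylim(0.1, 25)   # keep frequencies between 0.1 and 25 Hz
        ax.set_xlim(1, 59)     # show only 1 to 59 s, trimming the edges of the trace
        ax.set_axis_off()      # no ticks, labels or frame: only the image itself

    fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close(fig)  # free memory, since we write many figures in a loop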

In [None]:
image_dir = Path('spectrograms')
image_dir.mkdir(exist_ok=True)

for index in range(len(waveforms)):

    if index % 10 == 0:
        print('Computing spectrogram for event', index)

    write_spectrogram(
        waveforms[index, :, 0],
        sampling_freq,
        image_dir / f'event_{index}.png'
    )

Which files do we now have?

In [None]:
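# Sort by event number so that files[i] is the spectrogram of event i
# (Path.iterdir() returns files in arbitrary order)
files = sorted(image_dir.glob('event_*.png'), key=lambda p: int(p.stem.split('_')[1]))
print(files)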

In [None]:
# Check that we can read an image back in, resized to the 224 x 224 input size of the model
img = keras.utils.load_img(files[0], target_size=(224, 224))
x = keras.utils.img_to_array(img)
x /= 255.  # scale pixel values from 0-255 to [0, 1] for plotting

plt.imshow(x)
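Two details in the cell above are worth spelling out. `keras.utils.img_to_array` returns a `(height, width, channels)` float array with values from 0 to 255, which is why we divide by 255 before `plt.imshow`. A Keras model, on the other hand, expects a leading batch axis, which `np.expand_dims(..., axis=0)` adds. The sketch below uses a random NumPy array as a stand-in for a loaded image, so it runs without any image files.

```python
import numpy as np

# Stand-in for keras.utils.img_to_array(img): random 224x224 RGB pixels in 0-255 (not a real spectrogram)
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(224, 224, 3)).astype("float32")

x_display = x / 255.0            # values in [0, 1], suitable for plt.imshow
x_batch = np.expand_dims(x, 0)   # shape (1, 224, 224, 3): a batch containing one image
print(x.shape, x_batch.shape)
```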

## Download a pretrained computer vision model

With Keras, it's easy to download pretrained models: we just import the one we want and call it, and the weights are downloaded automatically the first time. We'll start with one named **ResNet50**.

In [None]:
from keras.applications.resnet50 import ResNet50
from keras.applications.resnet50 import preprocess_input  # ResNet50-specific input preprocessing
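`preprocess_input` must match how the network was trained. For ResNet50, Keras uses the "caffe" convention: channels are reordered from RGB to BGR and the ImageNet mean of each channel is subtracted, without rescaling to [0, 1]. That is why raw pixel values from 0 to 255 are passed to it. The NumPy re-implementation below (`caffe_style_preprocess` is a hypothetical name) is only for illustration; in the notebook, use the Keras function itself.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by Keras' "caffe"-mode preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype="float32")

def caffe_style_preprocess(x):
    """Illustrative re-implementation: RGB -> BGR, then subtract per-channel means.

    x: array of shape (..., 3) with RGB values in 0-255.
    """
    x = np.asarray(x, dtype="float32")
    bgr = x[..., ::-1]               # reverse channel order: RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR   # returns a new array; the input is left unchanged

# A single pure-red pixel (R=255, G=0, B=0) in a batch of one 1x1 image
pixel = np.array([[[[255.0, 0.0, 0.0]]]])
out = caffe_style_preprocess(pixel)
# In BGR order: blue = 0 - 103.939, green = 0 - 116.779, red = 255 - 123.68 = 131.32
print(out)
```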

Now here comes the trick: instead of using the full model, which is trained for classification, we run it _without_ its final classification layer (`include_top=False`). What comes out are the features the network has detected in each image, with no classification step applied. With `pooling='avg'`, these features are averaged over the image, giving one feature vector per image.

In [None]:
model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
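What `pooling='avg'` does: with `include_top=False`, ResNet50's last convolutional block produces a 7 × 7 grid of 2048-channel feature maps for a 224 × 224 input, and global average pooling averages each channel over that grid, leaving one vector of 2048 numbers per image. The sketch mimics this with random NumPy arrays standing in for real network activations.

```python
import numpy as np

# Stand-in for the un-pooled ResNet50 output: (batch, 7, 7, 2048) random activations
rng = np.random.default_rng(0)
feature_maps = rng.random((4, 7, 7, 2048), dtype=np.float32)

# Global average pooling: mean over the two spatial axes (height, width)
pooled = feature_maps.mean(axis=(1, 2))
print(pooled.shape)  # (4, 2048)
```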

In [None]:
# Load all images, then extract features for the whole batch in one call
images = []
for path in files:
    img = keras.utils.load_img(path, target_size=(224, 224))
    images.append(keras.utils.img_to_array(img))  # (224, 224, 3), values 0-255

images = np.stack(images, axis=0)  # (n_images, 224, 224, 3)

# preprocess_input may modify its input in place, so pass a copy
preds = model.predict(preprocess_input(images.copy()), batch_size=32, verbose=0)

images /= 255.  # keep a [0, 1] version of the images for plotting thumbnails

print('images.shape:', images.shape)
print('preds.shape:', preds.shape)
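The feature vectors can already be used directly: events whose spectrograms look alike should have feature vectors pointing in similar directions. A minimal sketch, assuming cosine similarity as the measure (a common but not the only choice), finds the events most similar to a given one. `most_similar` is a hypothetical helper; in the notebook you would pass `preds`.

```python
import numpy as np

def most_similar(features, query_index, k=3):
    """Return indices of the k rows most cosine-similar to row `query_index` (excluding itself)."""
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    unit = features / np.maximum(norms, 1e-12)   # normalise rows to unit length
    similarity = unit @ unit[query_index]        # cosine similarity to the query row
    similarity[query_index] = -np.inf            # exclude the query itself
    return np.argsort(similarity)[::-1][:k]

# Toy example (synthetic vectors, not real features): rows 0 and 2 point in nearly the same direction
toy = np.array([[1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.9, 0.1, 0.0],
                [0.0, 0.0, 1.0]])
print(most_similar(toy, query_index=0, k=1))  # [2]
```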

## Make use of the output

Okay! From the shapes printed above, we see that for each of the 100 images we get a vector of 2048 numbers. A 2048-dimensional space is hard to visualise, but we can use a dimensionality-reduction technique, here t-SNE, to plot it in 2D.

In [None]:
from sklearn.manifold import TSNE

(If you are not running in Colab, this needs `libopenblas-dev` to be installed.)

In [None]:
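# perplexity must be smaller than the number of samples (100 here); random_state makes the result reproducible
reducer = TSNE(n_components=2, perplexity=30, random_state=42)

embedding = reducer.fit_transform(preds)  # shape: (n_images, 2)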
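t-SNE is nonlinear and stochastic, and distances between well-separated groups in a t-SNE plot should not be over-interpreted. As a quick, deterministic cross-check, a linear projection with principal component analysis (PCA) can be computed with plain NumPy. The sketch below (an addition, not part of the original workflow, with `pca_2d` a hypothetical helper and synthetic data in place of `preds`) projects the vectors onto the two directions of largest variance.

```python
import numpy as np

def pca_2d(features):
    """Project rows of `features` onto their first two principal components."""
    centred = features - features.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centred, full_matrices=False)
    return centred @ vt[:2].T

# Synthetic stand-in for `preds`: 100 vectors of length 2048
rng = np.random.default_rng(0)
fake_features = rng.normal(size=(100, 2048))
coords = pca_2d(fake_features)
print(coords.shape)  # (100, 2)
```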

Let's plot the resulting 2D embedding of our events:

In [None]:
plt.scatter(embedding[:, 0], embedding[:, 1])

In [None]:
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

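def plot_tsne_with_thumbnails(tsne_coords, images, indices_to_show=None,
                              thumbnail_size=0.2, figsize=(10, 7)):
    """Scatter-plot 2D t-SNE coordinates and overlay image thumbnails at selected points."""
    if indices_to_show is None:
        indices_to_show = range(len(images))  # default: a thumbnail for every point

    fig, ax = plt.subplots(figsize=figsize)

    # Plot all points as a scatter plot
    ax.scatter(tsne_coords[:, 0], tsne_coords[:, 1], s=30)

    # Overlay a thumbnail image at each selected point
    for idx in indices_to_show:
        img = images[idx]
        x, y = tsne_coords[idx]

        imagebox = OffsetImage(img, zoom=thumbnail_size)  # zoom scales relative to the image's pixel size
        ab = AnnotationBbox(imagebox, (x, y), frameon=False, pad=0)
        ax.add_artist(ab)

    ax.set_xlabel('t-SNE 1')
    ax.set_ylabel('t-SNE 2')

    plt.tight_layout()
    return fig, ax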

# Select random indices to show thumbnails
n_thumbnails = 10
indices_to_show = np.random.choice(len(images), n_thumbnails, replace=False)

fig, ax = plot_tsne_with_thumbnails(
    embedding, 
    images, 
    indices_to_show=indices_to_show,
)

plt.show()
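
The plot shows groupings by eye but does not assign cluster labels. One way to get explicit labels (not part of the original workflow) is k-means on the 2048-dimensional feature vectors, for example with scikit-learn's `KMeans`. The number of clusters is an assumption you have to choose, e.g. by inspecting the t-SNE plot. The sketch uses synthetic data so it runs on its own, and `cluster_features` is a hypothetical helper; in the notebook you could pass `preds` and colour the scatter plot with `plt.scatter(embedding[:, 0], embedding[:, 1], c=labels)`.

```python
import numpy as np
from sklearn.cluster import KMeans

def cluster_features(features, n_clusters, seed=0):
    """Assign each feature vector to one of `n_clusters` k-means clusters."""
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    return kmeans.fit_predict(features)

# Synthetic stand-in for `preds`: two well-separated groups of 50 vectors each
rng = np.random.default_rng(0)
fake_features = np.vstack([
    rng.normal(0.0, 1.0, size=(50, 2048)),
    rng.normal(5.0, 1.0, size=(50, 2048)),
])
labels = cluster_features(fake_features, n_clusters=2)
print(np.bincount(labels))  # two clusters of 50 events each
```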