In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import cv2
import os
from google.colab.patches import cv2_imshow
from torchvision import datasets, transforms
from skimage import color
from skimage import io
from scipy import spatial

In [None]:
# use Colab GPU (ensure runtime type is set to GPU)

# Pick the first CUDA device when one is visible, otherwise fall back to CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [None]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder applies two stride-2 3x3 convolutions followed by an unpadded
    7x7 convolution producing 16 channels. Its output is flattened, projected
    to a low-dimensional embedding by ``linear1``, expanded back by
    ``linear2``, reshaped to the encoder output shape and passed through the
    transposed-convolution decoder, which ends in a sigmoid.

    Args:
        flat_features: Number of values in one flattened encoder output, i.e.
            ``16 * h * w`` where ``(h, w)`` is the spatial size the encoder
            produces for the input images. The default (17024) is the value
            the original notebook was written for; a 28x28 input, for
            example, needs ``flat_features=16``.
        embedding_dim: Size of the bottleneck embedding (default 50).
    """

    def __init__(self, flat_features=17024, embedding_dim=50):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 4, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, 8, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, 7)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 8, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 4, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(4, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
        # Bottleneck: flattened encoder features <-> embedding.
        self.linear1 = nn.Linear(flat_features, embedding_dim)
        self.linear2 = nn.Linear(embedding_dim, flat_features)

    def _encode(self, x):
        """Return ``(encoder_output_shape, embedding)`` for a batch ``x``."""
        enc_out = self.encoder(x)
        # Keep the batch dimension, flatten channels and spatial dims.
        flat = torch.flatten(enc_out, 1)
        return enc_out.shape, self.linear1(flat)

    def forward(self, x):
        """Encode ``x`` to the embedding and decode it back to image space.

        Args:
            x: Batch of single-channel images, shape ``(batch, 1, H, W)``.

        Returns:
            The reconstruction produced by the decoder (sigmoid output).
        """
        enc_shape, embedding = self._encode(x)
        x = self.linear2(embedding)

        # Restore the (batch, 16, h, w) layout expected by the decoder.
        reconstruction = torch.reshape(x, enc_shape)
        reconstruction = self.decoder(reconstruction)
        return reconstruction

    def get_embedding(self, x):
        """Return the bottleneck embedding of ``x`` without tracking gradients.

        Args:
            x: Batch of single-channel images, shape ``(batch, 1, H, W)``.

        Returns:
            Tensor of shape ``(batch, embedding_dim)``.
        """
        with torch.no_grad():
            _, embedding = self._encode(x)
            return embedding

In [None]:
def train(model, data, num_epochs=5, batch_size=8, learning_rate=0.0005):
    """Train ``model`` to reconstruct ``data`` using mini-batch Adam and MSE loss.

    The data is split once along dim 0 into consecutive chunks of
    ``batch_size`` (the last chunk may be smaller) and the same chunks are
    visited in the same order every epoch.

    Args:
        model: Module whose output has the same shape as its input.
        data: Tensor holding the whole training set, samples along dim 0.
        num_epochs: Number of passes over the data.
        batch_size: Number of samples per optimisation step.
        learning_rate: Adam learning rate.

    Returns:
        A list with one ``(epoch, img, recon)`` tuple per epoch, where ``img``
        is the last mini-batch of that epoch and ``recon`` is its detached
        reconstruction.
    """
    torch.manual_seed(42)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=1e-5)
    outputs = []
    batches = torch.split(data, batch_size, dim=0)

    for epoch in range(num_epochs):
        for batch in batches:
            # Train on the current mini-batch (previously the full dataset
            # was fed on every step, so batch_size had no effect).
            img = batch
            optimizer.zero_grad()
            recon = model(img)
            loss = criterion(recon, img)
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, float(loss)))
        # Detach so the stored history does not keep an autograd graph alive
        # for every epoch.
        outputs.append((epoch, img, recon.detach()))
    return outputs

In [None]:
# Folder that holds the word images -- fill this in before running.

path = ''
imgs = os.listdir(path)

# Word list: point words_npy at the .npy file; the loaded words are
# ordered case-insensitively.

words_npy = ''
words = sorted(list(np.load(words_npy)), key=str.casefold)

In [None]:
# load images
#
# Reads every .png in `path` into one float32 tensor of shape (N, 1, H, W).
# NOTE(review): files are visited in a deterministic order (numeric stems by
# value, other stems case-insensitively) on the assumption that row i should
# line up with words[i] -- confirm against how the images were named.

def _stem_sort_key(filename):
    """Sort numeric stems by value ('2' before '10'), others case-insensitively."""
    stem = os.path.splitext(filename)[0]
    if stem.isdecimal():
        return (0, int(stem), '')
    return (1, 0, stem.casefold())

# os.listdir returns entries in arbitrary order and may include non-image files.
image_files = sorted(
    (f for f in os.listdir(path) if f.lower().endswith('.png')),
    key=_stem_sort_key,
)
if not image_files:
    raise FileNotFoundError('no .png images found in {!r}'.format(path))

# The first image fixes the height/width every image is expected to share.
get_image = color.rgb2gray(io.imread(os.path.join(path, image_files[0])))
images = np.zeros(
    (len(image_files), 1, get_image.shape[0], get_image.shape[1]),
    dtype=np.float32,
)

for i, filename in enumerate(image_files):
    img = color.rgb2gray(io.imread(os.path.join(path, filename)))
    # rgb2gray already returns floats in [0, 1] for integer-typed images, so
    # no further division by 255 is applied (that would shrink the targets to
    # [0, 1/255], far below the range of the model's sigmoid output).
    images[i, 0, :, :] = img

images = torch.from_numpy(images)

In [None]:
# load model and train

# Seed before building the model: layer weights are drawn from torch's global
# RNG at construction time, so the seed set inside train() comes too late to
# make the initialisation repeatable between runs.
torch.manual_seed(42)

model = Autoencoder().to(device)
images = images.to(device)
max_epochs = 1000
outputs = train(model, images, num_epochs=max_epochs)

In [None]:
# visualise reconstructions
#
# One figure per 100 epochs: inputs on the top row, their reconstructions on
# the bottom row (previously the inputs were fetched but never drawn, leaving
# the top row of the 2x5 grid empty).

n_show = 5  # maximum number of samples shown per figure

for k in range(0, max_epochs, 100):
    # outputs[k] is the (epoch, inputs, reconstructions) tuple stored by train()
    imgs = outputs[k][1].cpu().numpy()
    recon = outputs[k][2].cpu().detach().numpy()

    fig, axes = plt.subplots(2, n_show, figsize=(36, 8), squeeze=False)
    for ax in axes.ravel():
        # Hide ticks, including on panels left empty when fewer than n_show
        # samples are available.
        ax.axis('off')

    for i, (original, item) in enumerate(zip(imgs[:n_show], recon[:n_show])):
        axes[0, i].imshow(original[0], cmap='gray')
        axes[0, i].set_title('input (epoch {})'.format(k))
        axes[1, i].imshow(item[0], cmap='gray')
        axes[1, i].set_title('reconstruction (epoch {})'.format(k))

    plt.show()

In [None]:
# get learned embeddings

# Put the model in evaluation mode, embed all images in a single call, then
# copy the result to host memory as a NumPy array.
model.eval()
embeddings_out = model.get_embedding(images)
embeddings = embeddings_out.to('cpu').numpy()

In [None]:
# calculate similarity matrix and save

num_words = len(words)

# Every word needs an embedding row; fail with a clear message instead of an
# IndexError part-way through the computation.
if embeddings.shape[0] < num_words:
    raise ValueError(
        'have {} embeddings but {} words'.format(embeddings.shape[0], num_words)
    )

# Pairwise cosine similarity = 1 - cosine distance, computed for all pairs in
# one vectorised call instead of a Python double loop.
word_embeddings = embeddings[:num_words]
sim = 1 - spatial.distance.cdist(word_embeddings, word_embeddings, metric='cosine')

np.save('weight_matrix_CONVAUTO.npy', sim)