<a href="https://colab.research.google.com/github/PacktPublishing/Hands-On-Computer-Vision-with-PyTorch/blob/master/Chapter11/conv_auto_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Quelle: V Kishore Ayyadevara and Yeshwanth Reddy, Modern Computer Vision with PyTorch

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim import Adam
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import random
from sklearn.manifold import TSNE

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"device: {device}")

In [None]:
def _to_device(tensor):
    """Move a single image tensor onto the globally selected ``device``."""
    return tensor.to(device)


# Per-sample preprocessing, applied in this order:
#   1. ToTensor   - convert the image to a float tensor
#   2. Normalize  - subtract 0.5 and divide by 0.5 (values in [0, 1] become [-1, 1])
#   3. Lambda     - move the sample to `device`, so batches arrive on the device already
img_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
    torchvision.transforms.Lambda(_to_device),
])

# MNIST train / test splits, downloaded into 'data/' when not present yet.
trn_ds = torchvision.datasets.MNIST(
    root='data/', train=True, transform=img_transform, download=True
)
test_ds = torchvision.datasets.MNIST(
    root='data/', train=False, transform=img_transform, download=True
)

# Only the training loader shuffles; the test loader keeps the dataset order.
batch_size = 32
trn_dl = DataLoader(dataset=trn_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(dataset=test_ds, batch_size=batch_size, shuffle=False)

In [None]:
class ConvAutoEncoder(nn.Module):
    """Convolutional auto-encoder made of an ``encoder`` and a ``decoder`` stack.

    The encoder applies two strided convolution + ReLU + max-pool stages
    (1 -> 32 -> 64 channels). The decoder applies three transposed
    convolutions (64 -> 32 -> 16 -> 1 channels) and ends in ``Tanh``, so
    reconstructions lie in [-1, 1].

    For a (N, 1, 28, 28) input the latent code is (N, 64, 2, 2) and the
    reconstruction is (N, 1, 28, 28).
    """

    def __init__(self):
        super().__init__()

        # Layers are created encoder-first so parameter initialisation order
        # (and the positional state_dict keys) stay stable.
        encoder_layers = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),
        ]
        decoder_layers = [
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=5, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=2, stride=2, padding=1),
            nn.Tanh(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the latent code and decode it back to image space."""
        return self.decoder(self.encoder(x))

In [None]:
def train_batch(input, model, criterion, optimizer):
    """Run one optimisation step of an auto-encoder on a single batch.

    The model is put into training mode, the batch is reconstructed, and the
    reconstruction is compared against the batch itself (the input is also the
    target).

    Args:
        input: Batch tensor fed to ``model`` and used as the target.
        model: Module to train.
        criterion: Loss callable invoked as ``criterion(output, input)``.
        optimizer: Optimizer holding the parameters of ``model``.

    Returns:
        The batch loss as a tensor detached from the autograd graph, so that
        callers can accumulate or log it without keeping the graph alive.
    """
    model.train()
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, input)
    loss.backward()
    optimizer.step()
    # The graph has served its purpose after backward(); hand back a plain value.
    return loss.detach()

In [None]:
def validate_batch(input, model, criterion):
    """Compute the reconstruction loss of ``model`` on one batch without training.

    The model is switched to evaluation mode and the forward pass runs with
    gradient tracking disabled. The batch serves as both input and target.

    Args:
        input: Batch tensor fed to ``model`` and used as the target.
        model: Module to evaluate.
        criterion: Loss callable invoked as ``criterion(output, input)``.

    Returns:
        The loss value returned by ``criterion``.
    """
    with torch.no_grad():
        model.eval()
        reconstruction = model(input)
        return criterion(reconstruction, input)

In [None]:
# Build the auto-encoder and place its parameters on the selected device.
model = ConvAutoEncoder()
model = model.to(device)

# Mean squared error between reconstruction and input.
criterion = nn.MSELoss()

# AdamW over all model parameters with a small weight decay.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=0.001,
    weight_decay=1e-5,
)

In [None]:
# Print torchsummary's overview of the model for a single-channel 28x28 input.
summary(model, input_size=(1, 28, 28))

In [None]:
num_epochs = 20

for epoch in range(num_epochs):
    # Accumulate plain Python floats (via .item()) rather than loss tensors:
    # summing tensors would keep autograd history alive for the whole epoch
    # and would print tensor reprs instead of numbers.
    lossTrainSum = 0.0
    for data, _ in trn_dl:
        lossTrainSum += train_batch(data, model, criterion, optimizer).item()

    lossTestSum = 0.0
    for data, _ in test_dl:
        lossTestSum += validate_batch(data, model, criterion).item()

    # Report the mean per-batch loss on both splits.
    print( f"epoch {epoch}  loss train {lossTrainSum/len(trn_dl)}  loss test {lossTestSum/len(test_dl)}" )

In [None]:
# Persist only the learned parameters (the state dict) under a relative path.
torch.save(model.state_dict(), "convAutoEncoder.pt")

In [None]:
def plotImages( img1, img2 ):
    """Show two images side by side in grayscale, without axes.

    Args:
        img1: Image shown on the left under the title 'input'.
        img2: Image shown on the right under the title 'prediction'.
    """
    _, axes = plt.subplots(1, 2)
    panels = ((img1, 'input'), (img2, 'prediction'))
    for axis, (image, title) in zip(axes, panels):
        axis.imshow(image, cmap='gray')
        axis.set_title(title)
        axis.axis('off')
    plt.show()

## Zufälliges Bild aus dem Trainingsdatensatz kodieren-->dekodieren

In [None]:
# Reconstruct three randomly chosen training images and show each next to its input.
for _trial in range(3):
    ix = np.random.randint(len(trn_ds))
    im = trn_ds[ix][0]
    # Indexing with [None] adds an extra leading dimension to the image
    # tensor "im" (1x28x28), turning it into a batch of one.
    batch = im[None]
    _im = model(batch)[0]
    original = im[0].cpu().detach().numpy()
    reconstruction = _im[0].cpu().detach().numpy()
    plotImages(original, reconstruction)

## Zufälliges Bild aus dem Testdatensatz kodieren-->dekodieren

In [None]:
# Reconstruct three randomly chosen test images and show each next to its input.
for _trial in range(3):
    ix = np.random.randint(len(test_ds))
    im = test_ds[ix][0]
    # Indexing with [None] adds an extra leading dimension to the image
    # tensor "im" (1x28x28), turning it into a batch of one.
    batch = im[None]
    _im = model(batch)[0]
    original = im[0].cpu().detach().numpy()
    reconstruction = _im[0].cpu().detach().numpy()
    plotImages(original, reconstruction)

In [None]:
def salt_and_pepper_noise(x, prob=0.05, salt_value=1.0, pepper_value=-1.0):
    """Return a copy of an image tensor with salt-and-pepper noise added.

    Each pixel is independently set to ``salt_value`` with probability
    ``prob``, to ``pepper_value`` with probability ``prob``, and left
    untouched otherwise. A single random draw per pixel decides the outcome,
    so both kinds of noise are equally likely.

    The defaults match images normalised to [-1, 1] (as done by
    ``Normalize([0.5], [0.5])`` in this file): salt is white (1.0) and pepper
    is black (-1.0).

    Args:
        x: Image tensor of shape (H, W), or (C, H, W) in which case only
            channel 0 is modified. The input itself is not changed.
        prob: Probability of salt, and separately of pepper, per pixel.
        salt_value: Value written for a salt pixel.
        pepper_value: Value written for a pepper pixel.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    noisy = x.clone()
    # `plane` is a view into `noisy`, so writes below land in the returned tensor.
    plane = noisy if noisy.dim() == 2 else noisy[0]
    height, width = plane.shape[-2:]
    for i in range(height):
        for j in range(width):
            draw = random.random()
            if draw < prob:
                plane[i, j] = salt_value
            elif draw > 1.0 - prob:
                plane[i, j] = pepper_value
    return noisy
    
    
# Draw a random training image and corrupt it with salt-and-pepper noise.
ix = np.random.randint(len(trn_ds))
im = trn_ds[ix][0]
noisy_tensor = salt_and_pepper_noise(im)
# Drop the channel axis, (1, 28, 28) -> (28, 28), and convert to a NumPy array.
im_noise = noisy_tensor[0].cpu().detach().numpy()

In [None]:
# get mean vector over training images in latent space
# call model.encoder(x) in loop and take mean of output
# use model.encoder(x).cpu().detach().numpy() to get numpy array
# use np.mean( ... , axis=0 ) to get mean vector

# for every class in the training set, compute the mean vector in the latent space
# use model.encoder(x).cpu().detach().numpy() to get numpy array
# use np.mean( ... , axis=0 ) to get mean vector

def get_latent_batch(input, model):
    """Encode a batch with ``model.encoder`` in eval mode, without gradient tracking.

    Args:
        input: Batch tensor passed to ``model.encoder``.
        model: Module exposing an ``encoder`` sub-module.

    Returns:
        The encoder output for ``input``.
    """
    with torch.no_grad():
        model.eval()
        return model.encoder(input)

# Encode every training batch and keep each batch's latent codes as a NumPy array.
outputs = [
    get_latent_batch(data, model).cpu().detach().numpy()
    for data, _ in trn_dl
]



In [None]:
def decode(input, model):
    """Decode latent codes with ``model.decoder`` in eval mode, without gradient tracking.

    Args:
        input: Latent tensor passed to ``model.decoder``.
        model: Module exposing a ``decoder`` sub-module.

    Returns:
        The decoder output for ``input``.
    """
    with torch.no_grad():
        model.eval()
        return model.decoder(input)

# Decode the first ten entries of `outputs` and show one reconstructed image each.
fig, ax = plt.subplots(1, 10)
for i in range(10):
    # outputs[i] is a NumPy array -> turn it into a tensor on the model's device
    # (NumPy arrays have no .to(), so the conversion has to happen first).
    im_tensor = torch.tensor(outputs[i]).to(device)
    decoded = decode(im_tensor, model)
    # imshow needs a 2-D array on the CPU: collapse any leading batch/channel
    # axes and take the first image.
    image = decoded.reshape(-1, *decoded.shape[-2:])[0]
    ax[i].imshow(image.cpu().numpy(), cmap='gray')

plt.show()

In [None]:
def get_latent_batch(input, model):
    """Encode a batch with ``model.encoder`` in eval mode, without gradient tracking.

    Args:
        input: Batch tensor passed to ``model.encoder``.
        model: Module exposing an ``encoder`` sub-module.

    Returns:
        The encoder output for ``input``.
    """
    with torch.no_grad():
        model.eval()
        return model.encoder(input)

# Group the latent codes of all training images by digit label in a single
# pass over the loader (instead of re-reading the whole dataset once per digit).
latents_by_digit = [[] for _ in range(10)]
for data, labels in trn_dl:
    latent = get_latent_batch(data, model).cpu().numpy()
    labels = labels.cpu().numpy()
    for digit in range(10):
        mask = labels == digit
        if mask.any():
            latents_by_digit[digit].append(latent[mask])

# outputs[d] is the mean latent code over all training images of digit d.
outputs = np.stack(
    [np.mean(np.concatenate(chunks, axis=0), axis=0) for chunks in latents_by_digit],
    axis=0,
)

def decode(input, model):
    """Decode latent codes with ``model.decoder`` in eval mode, without gradient tracking.

    Args:
        input: Latent tensor passed to ``model.decoder``.
        model: Module exposing a ``decoder`` sub-module.

    Returns:
        The decoder output for ``input``.
    """
    with torch.no_grad():
        model.eval()
        return model.decoder(input)

# Decode each digit's mean latent code and show the ten reconstructions in one row.
fig, ax = plt.subplots(1, 10)
for axis, mean_latent in zip(ax, outputs):
    im_tensor = torch.tensor(mean_latent).to(device)
    decoded = decode(im_tensor, model)
    reconstruction = decoded[0].cpu().detach().numpy()
    axis.imshow(reconstruction, cmap='gray')

plt.show()

# Choose two digits and move in latent space from one mean vector to the other:
#   - take the two per-digit mean latent vectors (outputs[d] is the mean of digit d)
#   - blend them linearly with 10 evenly spaced weights from torch.linspace(0, 1, 10)
# Still to do with the result:
#   - use model.decoder to decode the 10 latent vectors
#   - use torchvision.utils.make_grid to create a grid of the 10 images
#   - use plt.imshow to plot the grid

digits = [0, 8]

# The per-digit means were already computed above; index them by the chosen
# digits (digits has two entries, so the second digit is digits[1]).
mean_vector_0 = outputs[digits[0]]
mean_vector_8 = outputs[digits[1]]

start = torch.tensor(mean_vector_0)
end = torch.tensor(mean_vector_8)

# torch.linspace only accepts scalar end points, so build the 10 scalar blend
# weights with it and broadcast them over the latent dimensions.
weights = torch.linspace(0.0, 1.0, 10).view(-1, *([1] * start.dim()))

# Shape (10, *latent_shape): row 0 is the first digit's mean, row 9 the second's.
latent_vectors = start + weights * (end - start)


In [None]:
# Average the per-class mean latent vectors into one overall mean latent vector.
mean_all_classes = np.mean(outputs, axis=0)

fig, ax = plt.subplots(1, 1)

# Decode the overall mean and show the resulting image.
im_tensor = torch.tensor(mean_all_classes).to(device)
decoded = decode(im_tensor, model)
ax.imshow(decoded[0].cpu().detach().numpy(), cmap='gray' )

# torch.zeros_like needs a tensor (mean_all_classes is a NumPy array), so build
# the all-zero latent from the tensor version; it shares its shape, dtype and device.
zeros_like_mean = torch.zeros_like(im_tensor)

In [None]:
# outputs holds one mean latent vector per class.
# Find the pairwise Euclidean distances between all class mean vectors and
# report the two classes that lie farthest apart.
from itertools import combinations

# Flatten every class mean to a 1-D vector so that the norm below runs over
# the whole latent code and `distances` is a (num_classes, num_classes) matrix.
vectors = torch.tensor(outputs).flatten(start_dim=1)
distances = torch.norm(vectors[:, None] - vectors[None, :], dim=2)

# All unordered pairs of distinct classes (i < j); this skips the zero diagonal.
pairs = list(combinations(range(vectors.shape[0]), 2))

# find classes with largest distance
largest_pair = max(pairs, key=lambda pair: distances[pair].item())
largest_distance = distances[largest_pair]

# TODO: find classes with smallest distance; search over `pairs` as well,
# because a plain torch.argmin over `distances` would pick the zero diagonal.

print(f"largest distance: {largest_distance}")
print(f"classes with largest distance: {largest_pair[0]} and {largest_pair[1]}")