#  **Практическое занятие №8. Text2image модели. Основные архитектуры**

In [None]:
import numpy as np

In [None]:
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

## Inception Score

Документация по метрике Inception Score в PyTorch-Ignite: https://pytorch.org/ignite/generated/ignite.metrics.InceptionScore.html

### Применяем Inception v3 к конкретному изображению

In [None]:
# Pretrained Inception v3 from the torchvision hub repository (pinned to
# v0.10.0), switched to evaluation mode for inference.
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'inception_v3',
    pretrained=True,
)
model.eval()

In [None]:
from PIL import Image
from torchvision import transforms, datasets
import torchshow as ts

import io
import requests
import PIL

In [None]:
def download_image(url, timeout=30):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: address of the image.
        timeout: seconds to wait for the server. Without a timeout,
            ``requests`` can block forever on an unresponsive host.

    Returns:
        A ``PIL.Image.Image`` backed by the downloaded bytes.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

In [None]:
def preprocess(input_image):
    """Turn an image into a normalized single-image batch for Inception v3.

    The image is resized so that its shorter side is 299 px and then
    center-cropped to 299x299. It is converted to a float tensor and
    normalized with the ImageNet per-channel mean/std.

    Args:
        input_image: a ``PIL.Image.Image``. Non-RGB PIL images (grayscale,
            RGBA, palette) are converted to RGB first, because the 3-channel
            ``Normalize`` would otherwise fail on them.

    Returns:
        Tensor of shape ``(1, 3, 299, 299)``.
    """
    if isinstance(input_image, PIL.Image.Image) and input_image.mode != "RGB":
        input_image = input_image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(input_image)
    # Add the leading batch dimension expected by the model.
    return input_tensor.unsqueeze(0)

In [None]:
# Sample photo used to demonstrate Inception v3 classification.
image_url = 'https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'
image = download_image(image_url)

In [None]:
# Display the downloaded image inline (last expression of the cell).
image

In [None]:
# Resize, crop and normalize the image into a one-image batch for the model.
input_batch = preprocess(image)

In [None]:
# Inspect the batch shape; preprocess yields (1, channels, 299, 299).
input_batch.shape

In [None]:
# Inference only: run the forward pass without recording autograd history.
with torch.set_grad_enabled(False):
    output = model(input_batch)

In [None]:
# Raw scores (logits): one row per image in the batch, one column per class.
output.shape

In [None]:
# Convert the logits of the first (only) image into class probabilities.
probabilities = output[0].softmax(dim=0)

In [None]:
# Sanity check: softmax outputs should sum to 1. Tensor.sum() reduces in one
# vectorized call instead of looping over every element in Python.
probabilities.sum()

In [None]:
import matplotlib.pyplot as plt

# Plot the probability assigned to each class index.
plt.plot(np.arange(len(probabilities)), probabilities)
plt.show()  # must be called; a bare `plt.show` reference does nothing

In [None]:
# Index of the most probable class. Staying in torch works for tensors on any
# device; .item() returns a plain Python int.
probabilities.argmax().item()

In [None]:
# The next cell reads class names from imagenet_classes.txt in the working
# directory; download that file from:
# https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

In [None]:
# Load the class names (one per line) and print the five most probable
# classes for the image together with their probabilities.
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    categories = [line.strip() for line in f]

top5_prob, top5_catid = torch.topk(probabilities, 5)
for prob, cat_id in zip(top5_prob.tolist(), top5_catid.tolist()):
    print(categories[cat_id], prob)

### Посчитаем Inception Score на данных из CIFAR-10 (как будто эти картинки выдал генератор)

In [None]:
# CIFAR-10 preprocessing for Inception v3: resize and center-crop to 299x299,
# then convert to a float tensor.
# NOTE(review): ImageNet normalization is not applied here (presumably so the
# tensors can be displayed directly with imshow). `preprocess` normalizes
# before calling the model, so confirm the scoring path normalizes too.
cifar_transform = transforms.Compose(
    [
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
    ]
)

In [None]:
from PIL import Image
from torchvision import transforms, datasets
import torchshow as ts
import torchvision

import io
import requests
import PIL

In [None]:
# NOTE(review): IPython shell-style command, not plain Python (`%pip` is the
# explicit form). torchvision is already imported earlier in this notebook,
# so the pinned version likely takes effect only after a kernel restart.
pip install torchvision==0.15.1

In [None]:
# CIFAR-10 test split (downloaded into ./cifar if missing), transformed with
# cifar_transform; these images play the role of "generated" samples.
cifar10_testset = torchvision.datasets.CIFAR10(
    root='cifar',
    train=False,
    download=True,
    transform=cifar_transform,
)

# Human-readable names for the label indices 0..9.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
import matplotlib.pyplot as plt

# Preview the first 20 test images in a 2x10 grid, titled with their labels.
n_rows, n_cols = 2, 10
plt.figure(figsize=(15, 4))

for idx in range(n_rows * n_cols):
    picture, label = cifar10_testset[idx]
    plt.subplot(n_rows, n_cols, idx + 1)
    # ToTensor gives (C, H, W); imshow wants (H, W, C).
    plt.imshow(picture.permute(1, 2, 0))
    plt.xticks([])
    plt.yticks([])
    plt.title(classes[label])

plt.show()

In [None]:
from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

from torch.utils.data import DataLoader

In [None]:
from ignite.engine import Engine
from ignite.metrics import InceptionScore

In [None]:
# Feature extractor for InceptionScore: maps a batch of images to class
# probabilities of shape (batch, 1000).
# The CIFAR tensors stop at ToTensor (no ImageNet normalization), while the
# pretrained Inception model was fed ImageNet-normalized input in
# `preprocess`. Normalizing here gives the model inputs in the range it
# expects. Softmax then turns its logits into the probabilities that
# InceptionScore consumes.
inceptor = nn.Sequential(
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    model,
    nn.Softmax(1),
)

In [None]:
def process_function(engine, batch):
    """Identity step: pass the "generated" images unchanged to the metric."""
    return batch

engine = Engine(process_function)

# Inception Score over the 1000 class probabilities produced by `inceptor`.
metric = InceptionScore(num_features=1000, feature_extractor=inceptor)
metric.attach(engine, "is")

# Score one shuffled batch of 100 CIFAR-10 test images (labels are unused).
dataloader = DataLoader(cifar10_testset, batch_size=100, shuffle=True)
data, labels = next(iter(dataloader))

# A one-element dataset run for a single epoch -> exactly one iteration.
state = engine.run([data], max_epochs=1)
print(f"InceptionScore: {state.metrics['is']}")

### Автоэнкодер

In [None]:
import numpy as np
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torchshow as ts
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
# MNIST training images as tensors, randomly split 80/20 into train and
# validation subsets.
mnist_train_val = datasets.MNIST(
    root='mnist', train=True, download=True, transform=transforms.ToTensor()
)
mnist_train, mnist_val = random_split(mnist_train_val, lengths=[0.8, 0.2])

In [None]:
# Example: show the MNIST training sample at index 1 with torchshow.
img, cls = mnist_train_val[1]
ts.show(img)

In [None]:
# Shape of one MNIST sample tensor; presumably (1, 28, 28), matching the
# autoencoder's default img_shape.
img.shape

In [None]:
class SimpleAutoEncoder(nn.Module):
    """Fully connected autoencoder for small images.

    The encoder flattens each image and maps it through two LeakyReLU hidden
    layers to a ``latent_dim``-dimensional code. The decoder mirrors this,
    ends with a Sigmoid (outputs in [0, 1]) and reshapes the result to
    ``(batch, *img_shape)``.

    Args:
        latent_dim: size of the latent code.
        hidden_dim: width of the hidden layers.
        img_shape: shape of one image without the batch dimension.
    """

    def __init__(self, latent_dim, hidden_dim=1024, img_shape=(1, 28, 28)):
        super().__init__()
        flat_size = np.prod(img_shape)

        encoder_layers = [
            nn.Flatten(),
            nn.Linear(flat_size, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, flat_size),
            nn.Sigmoid(),
            nn.Unflatten(1, img_shape),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Map a batch of images to latent codes."""
        return self.encoder(x)

    def forward(self, x):
        """Encode then decode ``x``, returning the reconstruction."""
        return self.decoder(self.encoder(x))

    def process(self, x):
        """Return the reconstruction (same as calling the module)."""
        return self(x)

Напишем вспомогательные функции

In [None]:
from tqdm import tqdm

# One epoch of training or evaluation.
def run_epoch(ae, opt, loss, dataloader, is_train=True):
    """Run one pass over ``dataloader`` and return the loss per sample.

    The loss is computed as ``loss(x, ae(x))``. In training mode the
    optimizer ``opt`` steps after every batch. Otherwise gradients are
    disabled and ``opt`` is not used.

    Returns:
        Sum of the batch losses divided by the number of samples in
        ``dataloader.dataset``.
    """
    ae.train(is_train)
    running = 0.0
    with torch.set_grad_enabled(is_train):
        for batch, _labels in tqdm(dataloader):
            batch = batch.to(device)
            batch_loss = loss(batch, ae(batch))
            if is_train:
                opt.zero_grad()
                batch_loss.backward()
                opt.step()
            running += batch_loss.item()
    return running / len(dataloader.dataset)

In [None]:
# Loss-curve plotting.
def plot_loss(loss, title, num_epochs):
    """Draw one loss curve with a grid and one x tick per epoch index."""
    plt.title(title)
    plt.plot(loss)
    plt.grid()
    plt.xticks(np.arange(num_epochs))

def plot_losses(train, val, num_epochs):
    """Show train and validation loss curves side by side.

    Each panel's title contains the most recent (last) loss value.
    """
    plt.figure(figsize=(16, 4))
    panels = (
        (train, f'Train Loss = {train[-1]}'),
        (val, f'Val Loss = {val[-1]}'),
    )
    for position, (history, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plot_loss(history, title, num_epochs)
    plt.show()

In [None]:
# Draw real images next to their autoencoder reconstructions.
def show_examples(ae, dataset, size):
    """Display ``size`` random samples of ``dataset`` and their reconstructions.

    Indices are drawn with ``np.random.randint`` (with replacement). The model
    runs in eval mode without gradients, through ``ae.process``.
    """
    ae.eval()
    with torch.no_grad():
        picks = np.random.randint(0, len(dataset), size)
        originals = torch.stack([dataset[i][0] for i in picks]).to(device)
        reconstructions = ae.process(originals)
        for caption, images in (("Original images", originals),
                                ("Reconstructed", reconstructions)):
            print(caption)
            ts.show(images, nrows=1, figsize=(12, 2))

In [None]:
from IPython.display import clear_output

def run_train_loop(ae, opt, loss, train_loader, val_loader, num_epochs, ex_size):
    """Train ``ae`` for ``num_epochs`` epochs, validating after each one.

    After every epoch the cell output is cleared and redrawn. It shows the
    loss curves so far and ``ex_size`` reconstruction examples drawn from the
    validation dataset.
    """
    train_hist = []
    val_hist = []
    for _ in range(num_epochs):
        print("Training...")
        train_hist.append(run_epoch(ae, opt, loss, train_loader))
        print("Validating...")
        val_hist.append(run_epoch(ae, opt, loss, val_loader, is_train=False))
        clear_output()
        plot_losses(train_hist, val_hist, num_epochs)
        show_examples(ae, val_loader.dataset, ex_size)

In [None]:
batch_size = 256
num_epochs = 10
lat_dim = 8  # latent space dimensionality
ex_size = 8  # number of examples drawn after each epoch

# Shuffle the training set each epoch so mini-batches do not always come in
# the same fixed order; validation order does not matter.
train_loader = DataLoader(mnist_train, batch_size, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size)

sae = SimpleAutoEncoder(lat_dim).to(device)
sae_opt = optim.Adam(sae.parameters())
# Summed squared error per batch; run_epoch divides the epoch total by the
# dataset size to report a per-sample loss.
sae_loss = nn.MSELoss(reduction='sum')

run_train_loop(sae, sae_opt, sae_loss,
               train_loader, val_loader, num_epochs, ex_size)

In [None]:
# Take the first validation batch (images and labels) for inspection below.
data, labels = next(iter(val_loader))

In [None]:
# Encode one validation image. The encoder expects a batch on the model's
# device, so add a leading batch dimension and move the tensor there.
# No gradients are needed for inspection.
img = data[0]
with torch.no_grad():
    latent = sae.encode(img.unsqueeze(0).to(device))
latent

Посмотрим, как выглядит скрытое пространство автоэнкодера:

In [None]:
def show_enc_dist(ae, xi, yi, dataset, size):
    """Scatter-plot two latent coordinates of randomly chosen samples.

    ``size`` indices are drawn with replacement from ``dataset``. The samples
    are encoded with ``ae.encode`` (eval mode, no gradients), and each one is
    plotted at (z[xi], z[yi]), colored by its label.
    """
    picks = np.random.randint(0, len(dataset), size)
    ae.eval()
    with torch.no_grad():
        batch = torch.stack([dataset[i][0] for i in picks]).to(device)
        codes = ae.encode(batch)
    xs, ys = (codes[:, axis].cpu().numpy() for axis in (xi, yi))
    point_labels = [dataset[i][1] for i in picks]
    plt.figure(figsize=(8, 8))
    plt.scatter(xs, ys, c=point_labels)
    plt.show()

In [None]:
# Scatter latent coordinates 0 and 1 of 10000 random validation samples.
show_enc_dist(sae, 0, 1, mnist_val, 10000)

In [None]:
class SimpleVariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder for small images.

    For each latent dimension the encoder predicts a mean ``m`` and a
    positive scale ``s = exp(encoder_s(h))``. A latent code is sampled as
    ``z = m + s * eps`` with standard-normal ``eps``. The decoder maps ``z``
    back to an image in [0, 1] of shape ``(batch, *img_shape)``.

    Args:
        latent_dim: size of the latent code.
        img_shape: shape of one image without the batch dimension.
    """

    def __init__(self, latent_dim, img_shape=(1, 28, 28)):
        super().__init__()

        self.encoder_common = nn.Sequential(
            nn.Flatten(),
            nn.Linear(np.prod(img_shape), 512),
            nn.LeakyReLU(),
            nn.Linear(512, 512),
            nn.LeakyReLU()
        )
        self.encoder_m = nn.Linear(512, latent_dim)
        self.encoder_s = nn.Linear(512, latent_dim)

        # Kept for backward compatibility only. Sampling uses
        # torch.randn_like instead: this distribution's parameters are plain
        # CPU tensors (not buffers) and do not move with .to(device), so its
        # samples would cause a device mismatch once the model is on CUDA.
        self.N = torch.distributions.Normal(0, 1)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 512),
            nn.LeakyReLU(),
            nn.Linear(512, np.prod(img_shape)),
            nn.Sigmoid(),
            nn.Unflatten(1, img_shape)
        )

    def _encode_params(self, x):
        """Return the latent mean ``m`` and positive scale ``s`` for ``x``."""
        h = self.encoder_common(x)
        return self.encoder_m(h), torch.exp(self.encoder_s(h))

    @staticmethod
    def _sample(m, s):
        """Draw ``m + s * eps``, with ``eps ~ N(0, 1)`` on m's device/dtype."""
        return m + s * torch.randn_like(m)

    def encode(self, x):
        """Return a sampled latent code for a batch of images."""
        m, s = self._encode_params(x)
        return self._sample(m, s)

    def forward(self, x):
        """Return ``(reconstruction, m, s)`` for a batch of images."""
        m, s = self._encode_params(x)
        y = self.decoder(self._sample(m, s))
        return y, m, s

    def process(self, x):
        """Return only the reconstruction."""
        return self(x)[0]