```powershell
# PowerShell commands to download and unpack the dataset.
# Windows only: these do not run inside Jupyter, so execute them in a PowerShell terminal.
# Nothing to do when the dataset directory has already been extracted.
if ( !(Test-Path -Path .\cifar-10-batches-py) )
{
    # Download the archive only if an earlier run has not already fetched it,
    # so an existing archive without an extracted directory still gets unpacked.
    if ( !(Test-Path -Path .\toronto.tar) )
    {
        wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz -OutFile toronto.tar
    }
    tar -xzf toronto.tar
}
```

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pickle
import numpy as np
import tqdm
from IPython.display import HTML
from collections import defaultdict

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px

In [2]:
def plot_history(h, *metrics):
    """Print the latest value of each metric and draw one curve per metric.

    Args:
        h: Mapping from metric name to the sequence of recorded values.
        *metrics: Keys of ``h`` to report; each one gets its own subplot,
            laid out left to right in a single row.
    """
    for name in metrics:
        print(f"{name}: {h[name][-1]:.4f}")

    figure = make_subplots(rows=1, cols=len(metrics), subplot_titles=metrics)
    for column, name in enumerate(metrics, start=1):
        values = h[name]
        # The x axis is simply the position of each recorded value.
        steps = np.arange(len(values))
        figure.add_trace(go.Scatter(x=steps, y=values, name=name), row=1, col=column)

    figure.update_layout(title_text="Метрики")
    figure.show()

In [3]:
def unpickle(file):
    """Read and return the pickled object stored at path ``file``.

    The file is opened in binary mode and decoded with ``encoding='latin1'``.
    """
    # NOTE: unpickling can execute arbitrary code; only open trusted files.
    with open(file, 'rb') as stream:
        return pickle.load(stream, encoding='latin1')

def load_dataset(path, class_label):
    """Load the images of one class from a pickled batch file.

    Args:
        path: Path of a pickled dict holding parallel ``'data'`` and
            ``'labels'`` sequences.
        class_label: Only images whose label equals this value are kept.

    Returns:
        A list of float32 arrays, each computed as ``(image - 127.5) / 127.5``
        (which maps 0..255 pixel values onto -1..1).
    """
    batch = unpickle(path)
    return [
        (np.asarray(pixels, dtype=np.float32) - 127.5) / 127.5
        for pixels, label in zip(batch['data'], batch['labels'])
        if label == class_label
    ]

def save(obj, name):
    """Pickle ``obj`` into the file ``name``, replacing any previous content."""
    with open(name, 'wb') as sink:
        pickle.dump(obj, sink)
        
def load(name):
    """Return the object unpickled from the file ``name``."""
    # NOTE: unpickling can execute arbitrary code; only open trusted files.
    with open(name, 'rb') as source:
        restored = pickle.load(source)
    return restored

In [4]:
def prepare_image(image_arr, width=32, height=32) -> go.Image:
    """Turn a flat, channel-first image vector into a plotly ``go.Image`` trace.

    Args:
        image_arr: Flat sequence of ``3 * height * width`` values.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        A ``go.Image`` whose ``z`` is the (height, width, 3) array rescaled
        with ``(x + 1) * 255 / 2``.
    """
    # Split the flat pixel vector into its three colour planes...
    planes = np.array(image_arr).reshape(3, height, width)
    # ...then move the channel axis last, as an image trace expects.
    channels_last = planes.transpose(1, 2, 0)
    rescaled = (channels_last + 1) * 255 / 2

    return go.Image(z=rescaled)

def plot_images(images_arr, global_title="Картинки", titles=[""]*7, size_to_display=7):
    """Build a one-row figure showing the first images of ``images_arr``.

    Args:
        images_arr: Indexable collection of flat image vectors accepted by
            ``prepare_image``.
        global_title: Title of the whole figure.
        titles: Per-subplot titles handed to ``make_subplots``.
        size_to_display: Requested number of images; at most 13 are drawn and
            never more than ``images_arr`` contains.

    Returns:
        The plotly figure (not shown; call ``.show()`` on it).
    """
    # Limit the number of panels: hard cap of 13, and never index past the
    # images actually supplied (previously a short list raised IndexError).
    to_display = max(0, min(size_to_display, 13, len(images_arr)))

    # make_subplots needs at least one column; with nothing to draw the
    # single panel simply stays empty.
    figure = make_subplots(rows=1, cols=max(to_display, 1), subplot_titles=titles)
    for index in range(to_display):
        figure.add_trace(prepare_image(images_arr[index]), col=index + 1, row=1)
    figure.update_layout(title_text=global_title)
    return figure

def plot_diff(model, images_arr, device_type, size_to_display=7):
    """Show randomly chosen images (top row) above their reconstructions (bottom row).

    Args:
        model: Callable module mapping a batch of flat images to a batch of
            flat reconstructions.
        images_arr: Indexable collection of numpy image vectors.
        device_type: Device the input tensors are moved to before inference.
        size_to_display: Requested number of columns; at most 13 are drawn.
    """
    to_display = min(size_to_display, 13)
    # Random sample (with replacement) of the images to compare. The original
    # code drew these indexes but then always displayed the first images.
    indexes = np.random.randint(len(images_arr), size=to_display)
    chosen = [images_arr[index] for index in indexes]

    figure = make_subplots(rows=2, cols=to_display)

    for col, image in enumerate(chosen, start=1):
        figure.add_trace(prepare_image(image), col=col, row=1)

    # Pure inference: no autograd graph is needed for the reconstructions.
    with torch.no_grad():
        for col, image in enumerate(chosen, start=1):
            batch = torch.from_numpy(image).to(device_type).unsqueeze(0)
            restored = model(batch).detach().cpu()[0].numpy()
            figure.add_trace(prepare_image(restored), col=col, row=2)

    figure.update_layout(title_text="Разница")
    figure.show()

In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [6]:
# Keep only the images labelled 4 from the first training batch and serve them
# in shuffled mini-batches of 32.
# NOTE(review): label 4 is presumably "deer" in the CIFAR-10 label order - confirm.
train_data = load_dataset('cifar-10-batches-py/data_batch_1', 4)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

In [7]:
# Preview the first images of the selected training subset.
plot_images(train_data).show()

In [8]:
def fit(model, optim, crit, epochs, data):
    """Train ``model`` to reconstruct its own input.

    Args:
        model: Module called as ``model(batch)``.
        optim: Optimizer already bound to the model's parameters.
        crit: Loss called as ``crit(prediction, target)``; the target is the
            input batch itself.
        epochs: Number of passes over ``data``.
        data: Sized iterable of input batches (tensors); each batch is moved
            to the notebook-level ``device`` before the forward pass.

    Returns:
        A ``defaultdict(list)`` whose ``"loss"`` entry holds the mean batch
        loss of every epoch.
    """
    model.train()
    history = defaultdict(list)
    pbar = tqdm.trange(epochs, ascii=True)
    for _ in pbar:
        avg_loss = 0
        for batch in data:
            batch = batch.to(device)

            optim.zero_grad()

            output = model(batch)
            # Criteria follow the (input, target) convention: prediction first.
            # For a symmetric loss such as MSE the value is unchanged, but
            # asymmetric criteria would silently compute the wrong thing.
            loss = crit(output, batch)
            loss.backward()

            optim.step()
            # Accumulate the running mean of the batch losses for this epoch.
            avg_loss += loss.item() / len(data)
        history["loss"].append(avg_loss)
        pbar.set_description(f'Loss: {avg_loss:.8f}')

    # Hand cached GPU memory back once training is over; no autograd context
    # is needed around this call.
    torch.cuda.empty_cache()
    return history

In [9]:
class ConvAutoencoder(nn.Module):
    """Autoencoder with a convolutional encoder and a fully connected decoder.

    The model consumes flat vectors of 3072 values per sample (unflattened to
    3x32x32 inside the encoder) and produces flat vectors of 3072 values that
    pass through a final Tanh.

    Attributes:
        Encoder: Conv/pool/ReLU stack followed by Flatten -> Linear -> Tanh.
        Decoder: Three Linear layers (ReLU, ReLU, Tanh).
        CNN: Size of the convolutional stack's output for a batch of one.
        CNN_flatten: Number of features per sample after flattening that output.
    """

    def __init__(self, latent_space):
        """Build the encoder and the decoder.

        Args:
            latent_space: Length of the code vector produced by the encoder.
        """
        super().__init__()
        self.Encoder = nn.Sequential(
            nn.Unflatten(1, (3, 32, 32)),
            nn.Conv2d(3, 8, kernel_size=3),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(2, 2),
            nn.ReLU()
        )
        # Probe the convolutional stack to learn how many features the
        # bottleneck Linear layer has to accept.
        self.CNN, self.CNN_flatten = self._get_conv_output((3072,), self.Encoder)
        self.Encoder.append(nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.CNN_flatten, latent_space),
            nn.Tanh()
        ))

        self.Decoder = nn.Sequential(
            # Alternative transposed-convolution decoder, kept for reference
            # (not in use):
            # nn.Linear(latent_space, self.CNN_flatten),
            # nn.Unflatten(1, self.CNN[1:]),
            # nn.Upsample(scale_factor=2),
            # nn.ConvTranspose2d(32, 16, kernel_size=3),
            # nn.Upsample(scale_factor=2),
            # nn.ConvTranspose2d(16, 8, kernel_size=3),
            # nn.Upsample(scale_factor=2),
            # nn.ConvTranspose2d(8, 3, kernel_size=5),
            # nn.Flatten(),
            # nn.Tanh()
            nn.Linear(latent_space, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, 3072),
            nn.Tanh()
        )

    def _get_conv_output(self, shape, layers):
        """Return the output size of ``layers`` for a single dummy sample.

        Args:
            shape: Per-sample input shape (without the batch dimension).
            layers: Module to probe.

        Returns:
            ``(size, flat)``: the full output ``torch.Size`` for a batch of
            one, and the number of features per sample once flattened.
        """
        bs = 1
        # Only the shape matters, so feed defined values (zeros rather than
        # uninitialised memory) and skip building an autograd graph.
        with torch.no_grad():
            probe = layers(torch.zeros(bs, *shape))
        CNN = probe.size()
        CNN_flatten = probe.flatten(1).size(1)
        return CNN, CNN_flatten

    def forward(self, x):
        """Encode ``x`` and decode the resulting code back to input space."""
        return self.Decoder(self.Encoder(x))

In [22]:
# Fully connected autoencoder: every Linear layer is followed by a Tanh.
# Encoder widths: 3072 -> 10240 -> 4096 -> 2048 -> 1024 -> 96.
encoder = nn.Sequential(*[
    layer
    for fan_in, fan_out in zip(
        (3072, 1024*10, 1024*4, 1024*2, 1024*1),
        (1024*10, 1024*4, 1024*2, 1024*1, 96),
    )
    for layer in (nn.Linear(fan_in, fan_out), nn.Tanh())
])

# Decoder mirrors the encoder: 96 -> 1024 -> 2048 -> 4096 -> 10240 -> 3072.
decoder = nn.Sequential(*[
    layer
    for fan_in, fan_out in zip(
        (96, 1024*1, 1024*2, 1024*4, 1024*10),
        (1024*1, 1024*2, 1024*4, 1024*10, 3072),
    )
    for layer in (nn.Linear(fan_in, fan_out), nn.Tanh())
])

# Chain both halves and move all parameters to the selected device.
model = nn.Sequential(encoder, decoder).to(device)

In [23]:
# Train for 500 epochs with Adam (lr=1e-4) and an MSE reconstruction loss.
hist = fit(model, torch.optim.Adam(model.parameters(), lr=1e-4), nn.MSELoss(), 500, train_loader)

Epoch: 500. Loss: 0.01092679: 100%|##########| 500/500 [11:16<00:00,  1.35s/it]


In [24]:
# Print the final loss and plot the loss curve over the epochs.
plot_history(hist, "loss")

loss: 0.0109


In [25]:
# Persist the whole module object (not only its state_dict).
# NOTE(review): assumes the ./models directory already exists - confirm.
torch.save(model, './models/liner_big.pkl')

In [11]:
# Replace `model` with a module restored from disk, mapped onto `device`.
# NOTE(review): this reads liner.pkl, not the liner_big.pkl written by the
# save cell - confirm which checkpoint is intended.
# NOTE(review): newer PyTorch releases may refuse to unpickle a whole module
# unless weights_only=False is passed - verify against the installed version.
model = torch.load("./models/liner.pkl", map_location=device)

In [34]:
def modifications(decoder, encoder, image, latent_space, count=5):
    """Plot an image, its reconstruction and five randomly perturbed decodings.

    The image is encoded once; for each of the five variants, ``count``
    randomly chosen code entries (indices may repeat) are overwritten with
    uniform values in [-1, 1) before decoding.

    Args:
        decoder: Module mapping a batch of codes to a batch of flat images.
        encoder: Module mapping a batch of flat images to a batch of codes.
        image: Flat numpy image vector.
        latent_space: Upper bound (exclusive) for the code indices that may be
            perturbed; clamped to the real code length.
        count: Number of code entries redrawn per variant.
    """
    def _to_batch(vector):
        # numpy vector -> one-sample batch on the notebook-level device
        return torch.from_numpy(vector).to(device).unsqueeze(0)

    def _first_as_numpy(batch_output):
        # first sample of a batch -> numpy array on the CPU
        return batch_output.detach().cpu()[0].numpy()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        features = _first_as_numpy(encoder(_to_batch(image)))
        # Never draw an index beyond the code the encoder really produced;
        # a too-large `latent_space` used to raise IndexError below.
        index_limit = min(latent_space, features.shape[0])

        images = [image, _first_as_numpy(decoder(_to_batch(features)))]
        titles = ['Input', 'Output', 'Modifications', '', '', '', '']
        for _ in range(5):
            idx = np.random.randint(index_limit, size=count)
            mod = np.copy(features)
            mod[idx] = np.random.rand(count) * 2 - 1
            images.append(_first_as_numpy(decoder(_to_batch(mod))))
    plot_images(images, titles=titles, global_title=f"Модификация {count}").show()

    

In [27]:
# Switch the model to evaluation mode, then compare training images with
# their reconstructions.
model.eval()
plot_diff(model, train_data,device_type=device)

In [53]:
# Split the Sequential into its two halves, then perturb 24 entries of the
# code (indices drawn below 96) of one random training image.
encoder = model[0]
decoder = model[1]
modifications(decoder, encoder, train_data[np.random.randint(len(train_data))], 96, 24)

In [45]:
from jupyter_dash import JupyterDash
from dash import dcc, html, Input, Output, Dash

# Stylesheet loaded by the Dash app.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = JupyterDash(__name__, external_stylesheets=external_stylesheets)

# Layout: one slider (10..200 in steps of 5, initially 90) and a text
# container that the callback below fills in.
app.layout = html.Div([
    dcc.Slider(10, 200, 5,
               value=90,
               id='my-slider'
    ),
    html.Div(id='slider-output-container')
])

@app.callback(
    Output('slider-output-container', 'children'),
    Input('my-slider', 'value'))
def update_output(value):
    """Re-run ``modifications`` on a random training image when the slider moves.

    Args:
        value: Current slider value.

    Returns:
        The text shown in the output container.
    """
    # NOTE(review): the slider value (10-200) is passed as `latent_space`,
    # although the encoder built in this notebook emits 96 features - confirm
    # this is the intended argument.
    modifications(decoder, encoder, train_data[np.random.randint(len(train_data))], value, 5)
    return 'You have selected "{}"'.format(value)

In [15]:
# Start the Dash server when this code runs as the main module.
if __name__ == '__main__':
    app.run_server()

Dash app running on http://127.0.0.1:8050/
