In [None]:
"""
This notebook contains code for the implementation and training of a simple fully connected (MLP) model, as well as various utility functions.
"""

from math import ceil
import numpy as np
import random

from os.path import getsize
import glob
from pathlib import Path

from collections import defaultdict

import matplotlib.pyplot as plt

from skimage import data, img_as_float
import cv2

import torch
from torch.utils.data import TensorDataset, SubsetRandomSampler

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import sys

In [None]:
def get_tileset(fp):
    """Read a binary tileset file.

    Layout, as parsed below: three single bytes (tile count, rows, cols),
    followed by `tile count` blocks of rows*cols unsigned bytes.

    Returns:
        (n_rows, n_cols, tileset) where tileset stacks each tile's flattened
        bytes (one row per tile).
    """
    with open(fp, 'rb') as stream:
        n_tiles, n_rows, n_cols = [ord(stream.read(1)) for _ in range(3)]
        tile_len = n_rows * n_cols
        tiles = [np.frombuffer(stream.read(tile_len), dtype=np.ubyte)
                 for _ in range(n_tiles)]
    return n_rows, n_cols, np.asarray(tiles)

def get_xy_pairs(fp):
    """Read (input, label) training pairs from a binary file.

    File layout, as parsed below (native byte order):
      * int32 n_rows, int32 n_cols
      * repeated pairs of: n_rows*n_cols float32 values, then one int32 label

    Returns:
        n_rows, n_cols: the header values (numpy int32 scalars).
        xs: float32 array of shape (n_pairs, n_rows*n_cols).
        ys: int32 array of shape (n_pairs,). It is always 1-D, even for a single
            pair (the old np.squeeze produced a 0-d array in that case).

    Raises:
        AssertionError: if the payload is not a whole number of pairs, or the
            file yields fewer pairs than its size implies.
    """
    file_size = getsize(fp)

    with open(fp, 'rb') as _if:
        n_rows, n_cols = np.fromfile(_if, np.int32, count=2)

        n_pixels = int(n_rows) * int(n_cols)
        pair_bytes = 4 * (n_pixels + 1)
        payload_bytes = file_size - 8

        assert payload_bytes % pair_bytes == 0, (
            f'{fp}: {payload_bytes} payload bytes is not a multiple of the '
            f'{pair_bytes}-byte pair size')
        n_pairs = payload_bytes // pair_bytes

        # Read every pair in a single call, instead of 2 * n_pairs tiny reads.
        pair_dtype = np.dtype([('x', np.float32, (n_pixels,)), ('y', np.int32)])
        pairs = np.fromfile(_if, dtype=pair_dtype, count=n_pairs)

    assert len(pairs) == n_pairs, f'{fp}: expected {n_pairs} pairs, read {len(pairs)}'

    # Copy the field views into plain contiguous arrays for torch.from_numpy.
    xs = np.ascontiguousarray(pairs['x'])
    ys = np.ascontiguousarray(pairs['y'])
    return n_rows, n_cols, xs, ys

def show_pair(x, y):
    
    plt.figure()
    
    if x.dtype == np.float32:
        x = (x*255).astype(np.ubyte)
    if y.dtype == np.float32:
        y = (y*255).astype(np.ubyte)
    
    plt.subplot(121)
    plt.imshow(x, cmap='gray', vmin=0, vmax=255)
    plt.subplot(122)
    plt.imshow(y, cmap='gray', vmin=0, vmax=255)
    
def show_imgs(imgs, n_cols=4, shape=None):
    """Plot images in a grid with `n_cols` columns in a new figure.

    Args:
        imgs: sequence of image arrays.
        n_cols: number of grid columns; as many rows are added as needed.
        shape: if given, each image is reshaped to it first (e.g. to unflatten
            a row vector); if None, images are drawn as-is.

    Images are drawn on a fixed 0-255 gray scale.
    """
    plt.figure()

    n_rows = ceil(len(imgs) / n_cols)

    for i, img in enumerate(imgs):
        # Use the 3-argument form. The packed 3-digit code (e.g. 341) only works
        # for grids with fewer than 10 cells.
        plt.subplot(n_rows, n_cols, i + 1)
        # NumPy treats reshape(None) as reshape(()), which fails for any
        # non-scalar image, so only reshape when a shape was requested.
        to_draw = img if shape is None else img.reshape(shape)
        plt.imshow(to_draw, cmap='gray', vmin=0, vmax=255)
        
def get_loaders(xs, ys, batch_size = 64, val_split=.2):
    """Build train/validation DataLoaders over a random split of (xs, ys).

    Args:
        xs: inputs (numpy array or tensor of any numeric dtype); stored as float32.
        ys: labels (numpy array or tensor); stored as float32. Training code
            casts them back with `.long()`.
        batch_size: batch size for both loaders.
        val_split: fraction of samples (rounded down) held out for validation.

    Returns:
        (train_loader, val_loader): loaders over disjoint subsets. Each draws
        its own subset in a fresh random order every epoch.
    """
    # as_tensor accepts numpy arrays and tensors of any dtype. The previous
    # torch.Tensor(...) wrapper rejected tensors that were not already float32
    # (e.g. float64 image data).
    xs = torch.as_tensor(xs, dtype=torch.float32)
    ys = torch.as_tensor(ys, dtype=torch.float32)

    dataset = TensorDataset(xs, ys)

    n_samples = len(dataset)
    split = int(np.floor(val_split * n_samples))

    # A single random permutation decides the split. Shuffling the data itself
    # is unnecessary because the samplers reshuffle each subset every epoch.
    perm = torch.randperm(n_samples)
    train_indices, val_indices = perm[split:], perm[:split]

    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=val_sampler)

    return train_loader, val_loader
        

In [None]:
class Model(nn.Module):
    """Four-layer fully connected classifier: ReLU between layers, raw logits out.

    The defaults (195 inputs, 256 hidden units, 95 outputs) reproduce the
    original hard-coded architecture. Layer names (fc1..fc4) are unchanged, so
    existing checkpoints and state_dict-based exports stay compatible.

    Args:
        in_features: length of each flattened input vector.
        hidden: width of the three hidden layers.
        n_classes: number of output logits.
    """
    def __init__(self, in_features=195, hidden=256, n_classes=95):
        super().__init__()

        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, n_classes)

    def forward(self, x):
        """Map inputs of size `in_features` (last dim) to `n_classes` unnormalized scores."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)

        return x

def get_accuracy(model, loader, device=None):
    """Print and return the top-1 accuracy of `model` over `loader`.

    Args:
        model: module mapping a batch of inputs to per-class scores of shape
            (batch, n_classes).
        loader: iterable of (inputs, labels) batches; labels are cast to long.
        device: device to evaluate on. Defaults to the device of the model's
            first parameter (CPU for a parameter-free model), so this no longer
            depends on a notebook-global `device` defined in a later cell.

    Returns:
        Fraction of correctly classified samples, or NaN for an empty loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    num_correct = 0
    num_ys = 0

    with torch.no_grad():
        for data in loader:
            xs, ys = data[0].to(device), data[1].long().to(device)

            preds = torch.argmax(model(xs), 1)

            # Tensor .sum() instead of Python's sum(), which iterates element by element.
            num_correct += (preds == ys).sum().item()
            num_ys += len(ys)

    accuracy = num_correct / num_ys if num_ys else float('nan')
    print(f'Got {num_correct} / {num_ys} with accuracy {accuracy*100:.2f}')

    return accuracy

def get_topk_accuracy(model, loader, k=3, device=None):
    """Print and return the top-k accuracy of `model` over `loader`.

    A sample counts as correct when its label is among the k highest-scoring classes.

    Args:
        model: module mapping a batch of inputs to per-class scores of shape
            (batch, n_classes).
        loader: iterable of (inputs, labels) batches; labels are cast to long.
        k: number of top candidates to accept. Values above n_classes are
            clamped, because torch.topk would raise for them.
        device: device to evaluate on. Defaults to the device of the model's
            first parameter (CPU for a parameter-free model).

    Returns:
        Fraction of samples whose label is in the top k, or NaN for an empty loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    num_correct = 0
    num_ys = 0

    with torch.no_grad():
        for data in loader:
            xs, ys = data[0].to(device), data[1].long().to(device)

            outputs = model(xs)
            top = torch.topk(outputs, min(k, outputs.size(1)), dim=1).indices

            # (batch, k) == (batch, 1) broadcasts, so the per-sample loop is not needed.
            num_correct += (top == ys.unsqueeze(1)).any(dim=1).sum().item()
            num_ys += len(ys)

    accuracy = num_correct / num_ys if num_ys else float('nan')
    print(f'Got {num_correct} / {num_ys} with top-{k} accuracy {accuracy*100:.2f}')

    return accuracy

In [None]:
"""Training!"""

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)

val_losses = []
val_max = 3

for epoch in range(200):
    
    running_loss = 0.0
    
    for i, data in enumerate(train_loader, 0):
        xs, ys = data[0].to(device), data[1].long().to(device)
        
        optimizer.zero_grad()
        
        outputs = model(xs)
        loss = criterion(outputs, ys)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print('[%d] training loss: %.3f' %
          (epoch + 1, running_loss / len(train_loader)))
    
    if epoch % 5 == 4:

        val_loss = 0.0
        
        with torch.no_grad():
            
            
            for i, data in enumerate(val_loader, 0):
                xs, ys = data[0].to(device), data[1].long().to(device)

                outputs = model(xs)
                loss = criterion(outputs, ys)

                val_loss += loss.item()

                    
            print('[%d] total val loss: %.3f' %
                  (epoch + 1, val_loss / len(val_loader)))
        
        torch.save(model, 'model{}.pt'.format(epoch))
        
  
            
print('Finished')

In [None]:
def convert_to_ascii(img, model, tileset, tile_shape, slice_shape, device=None):
    """Render `img` as a mosaic of tiles chosen by `model`.

    The image is scanned in steps of one tile. At each step, a window of
    `slice_shape` whose top-left corner coincides with the tile's is flattened
    and scored by the model. The highest-scoring entry of `tileset` is then
    placed at that tile position.

    Args:
        img: image indexed as img[row, col] (presumably grayscale; windows are
            flattened in row-major order and cast to float32).
        model: maps a batch of flattened windows to per-tile scores of shape
            (n_windows, n_tiles).
        tileset: array indexed by tile id; each entry reshapes to `tile_shape`.
        tile_shape: (tile_rows, tile_cols) of each output tile.
        slice_shape: (slice_rows, slice_cols) of the window fed to the model.
        device: device to run the model on. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model).

    Returns:
        Array of shape (rows*tile_rows, cols*tile_cols) with tileset's dtype.
        The array is zero-sized when the image is too small to scan.
    """
    tile_rows, tile_cols = tile_shape
    slice_rows, slice_cols = slice_shape

    x_pad = slice_cols - tile_cols
    y_pad = slice_rows - tile_rows

    # NOTE(review): a margin of one pad would already keep every window in
    # bounds. The 2*pad margin is kept so the output size does not change.
    rows = max((img.shape[0] - 2*y_pad) // tile_rows, 0)
    cols = max((img.shape[1] - 2*x_pad) // tile_cols, 0)

    tileset = np.asarray(tileset)
    if rows == 0 or cols == 0:
        # Previously np.hstack([]) / np.vstack([]) raised here.
        return np.empty((rows * tile_rows, cols * tile_cols), dtype=tileset.dtype)

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Gather every window in row-major order and score them all in one forward
    # pass, instead of one model call per window.
    patches = np.stack([
        img[tile_rows*row : tile_rows*row + slice_rows,
            tile_cols*col : tile_cols*col + slice_cols].reshape(-1)
        for row in range(rows)
        for col in range(cols)
    ])

    with torch.no_grad():
        scores = model(torch.from_numpy(patches).float().to(device))
        tile_ids = torch.argmax(scores, dim=1).cpu().numpy()

    tiles = tileset[tile_ids].reshape(rows, cols, tile_rows, tile_cols)
    # (rows, cols, tr, tc) -> (rows, tr, cols, tc), then stitch into one image.
    return tiles.transpose(0, 2, 1, 3).reshape(rows * tile_rows, cols * tile_cols)

def get_indices(img, model, tileset, tile_shape, slice_shape, device=None):
    """Return the model's tile choice for every scanned window of `img`, in row-major order.

    The image is scanned in steps of one tile. At each step, a window of
    `slice_shape` anchored at the tile's top-left corner is flattened and
    scored, and the argmax is recorded.

    Args:
        img: image indexed as img[row, col]; windows are cast to float32.
        model: maps a batch of flattened windows to per-tile scores of shape
            (n_windows, n_tiles).
        tileset: not used here; kept for signature compatibility.
        tile_shape: (tile_rows, tile_cols) scan step.
        slice_shape: (slice_rows, slice_cols) of the window fed to the model.
        device: device to run the model on. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model).

    Returns:
        List of 0-d integer tensors (on `device`), one per window. The list is
        empty when the image is too small to scan.
    """
    tile_rows, tile_cols = tile_shape
    slice_rows, slice_cols = slice_shape

    x_pad = slice_cols - tile_cols
    y_pad = slice_rows - tile_rows

    rows = (img.shape[0] - 2*y_pad) // tile_rows
    cols = (img.shape[1] - 2*x_pad) // tile_cols

    if rows <= 0 or cols <= 0:
        return []

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Score all windows in a single batched forward pass.
    patches = np.stack([
        img[tile_rows*row : tile_rows*row + slice_rows,
            tile_cols*col : tile_cols*col + slice_cols].reshape(-1)
        for row in range(rows)
        for col in range(cols)
    ])

    with torch.no_grad():
        scores = model(torch.from_numpy(patches).float().to(device))

    # list() of a 1-D tensor yields 0-d tensors, matching the previous return type.
    return list(torch.argmax(scores, dim=1))
    

def asciivate(fp, process=False):
    """Load the image at `fp` and render it with `convert_to_ascii`.

    The image is converted to grayscale and inverted before rendering. With
    `process=True` it is first reduced to blurred edges (bilateral filter,
    then Canny, then Gaussian blur).

    Relies on the notebook globals `model`, `tileset`, `tile_rows`,
    `tile_cols`, `slice_rows` and `slice_cols`. NOTE(review): none of these is
    defined in a visible cell; confirm where they come from.

    Returns:
        The rendered image, or None if OpenCV cannot read `fp`.
    """
    source = cv2.imread(fp)
    if source is None:
        return None

    gray = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)

    if process:
        smoothed = cv2.bilateralFilter(gray, 7, 400, 400)
        edges = cv2.Canny(smoothed, 8, 15)
        gray = cv2.GaussianBlur(edges, (3, 3), 0)

    # Bitwise NOT on the integer image inverts intensities before rescaling.
    inverted = img_as_float(~gray)

    return convert_to_ascii(inverted, model, tileset, (tile_rows, tile_cols), (slice_rows, slice_cols))

In [None]:
def save_raw_model(fp, model, verbose=True):
    """Dump a model's state_dict to a flat binary file.

    For each state_dict entry, in iteration order:
      * 2-D tensors are transposed first (an nn.Linear weight of shape
        (out, in) is stored as (in, out));
      * entries whose key contains 'weight' are preceded by their
        (post-transpose) dimensions, each as a 4-byte unsigned int in the host
        byte order (sys.byteorder);
      * the values follow as raw row-major bytes in the tensor's dtype.

    Tensors are copied to the CPU first, so models that were moved to a GPU
    (as the training cell does) can be exported directly. Previously
    `.numpy()` raised for CUDA tensors.

    Args:
        fp: output path (overwritten).
        model: module whose state_dict is written.
        verbose: print each key, its shape and payload size (the previous
            always-on behaviour).
    """
    with open(fp, 'wb') as _if:
        for key, val in model.state_dict().items():

            tensor = val.detach().cpu()

            if tensor.dim() == 2:
                tensor = tensor.transpose(0, 1)
            dims = tensor.shape

            if verbose:
                print(key)
                print(tensor.shape)

            if 'weight' in key:
                for dim in dims:
                    _if.write(int(dim).to_bytes(4, sys.byteorder))

            # flatten() copies the transposed view into row-major order.
            payload = tensor.flatten().numpy().tobytes()

            if verbose:
                print(len(payload))
            _if.write(payload)