In [1]:
from nnhcs import *

ModuleNotFoundError: No module named 'nnhcs'

In [None]:
# import tqdm
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum
from IPython import display
from PIL import Image, ImageDraw, ImageFont
from tqdm.notebook import tqdm as tqdm_nb
import tqdm
from typing import Any
import dataclasses
from dataclasses import field
import imageio
import itertools
from loguru import logger
import numpy as np
import operator
import os
import pathlib
import pickle
import queue
import random
import shelve
import shutil
import subprocess
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import typing
import uuid
from collections import OrderedDict
import collections
import numpy as np
import gc

In [None]:
# Torch device used by the notebook: the CUDA GPU when one is available,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def put_shelve(name, data, db_path='my.db'):
    """Store ``data`` under the key ``name`` in a shelve database.

    Args:
        name: Key the value is stored under.
        data: Value to persist (shelve pickles it).
        db_path: Shelve database to open. Defaults to ``'my.db'``, the file
            this helper has always written to, so existing calls are unchanged.
    """
    with shelve.open(db_path) as db:
        db[name] = data

def get_shelve(name, db_path='my.db'):
    """Return the value stored under ``name`` in a shelve database.

    Args:
        name: Key to look up.
        db_path: Shelve database to open. Defaults to ``'my.db'``, the file
            this helper has always read from, so existing calls are unchanged.

    Raises:
        KeyError: If ``name`` is not present in the database.
    """
    with shelve.open(db_path) as db:
        return db[name]

In [None]:
def put_pickle(name, data):
    """Pickle ``data`` into the file ``name + '.pkl'`` (overwriting it)."""
    target = name + '.pkl'
    with open(target, 'wb') as handle:
        pickle.Pickler(handle).dump(data)
        
def get_pickle(name):
    """Load and return the object pickled in the file ``name + '.pkl'``."""
    source = name + '.pkl'
    # Unpickling can execute arbitrary code: only load files you created.
    with open(source, 'rb') as handle:
        return pickle.Unpickler(handle).load()

In [None]:
class Algorithm(object):
    """Interface for agents that choose the next position to move to."""

    def get_next_pos(self, curr, goal, env, time=None):
        """Return the position to move to from ``curr``.

        ``time`` is optional so the method can be called both with three
        arguments and with the four positional arguments that ``run`` passes
        (position, goal, environment, step counter).
        """
        raise NotImplementedError


class LRTA(Algorithm):
    """One-step lookahead agent that keeps and raises a per-cell heuristic table.

    Attributes (all set by ``init_h`` except ``weight``):
        h: Array indexed ``[x, y]`` holding the current heuristic estimate.
        h_updates: Array indexed ``[x, y]`` accumulating how much ``h`` was
            raised at each cell.
        weight: Factor applied to ``cost + h`` when scoring a neighbour.
    """

    def __init__(self, weight=1):
        self.h = None
        self.h_updates = None
        self.weight = weight

    def init_h(self, env, goal):
        """Reset both tables and seed ``h`` with ``Pos.heuristic`` on road cells."""
        # NOTE(review): zeros_like inherits env.bitmap's dtype. If that dtype
        # is integral, fractional heuristic values are truncated on
        # assignment -- confirm the bitmap dtype.
        self.h = np.zeros_like(env.bitmap)
        self.h_updates = np.zeros_like(env.bitmap)
        for pos in env:
            if env.get(pos) == World.ROAD:
                self.h[pos.x, pos.y] = Pos.heuristic(pos, goal)

    def get_next_pos(self, curr, goal, env, time=None):
        """Pick the neighbour of ``curr`` with the lowest ``(cost + h) * weight``.

        Ties are broken uniformly at random. If the winning score exceeds
        ``h`` at ``curr``, ``h[curr]`` is raised to it and the increase is
        added to ``h_updates[curr]``.

        Args:
            curr: Current position (needs ``.x`` / ``.y``).
            goal: Unused here; part of the ``Algorithm`` interface.
            env: Provides ``get_adjs(pos)`` and ``get_cost(a, b)``.
            time: Unused here; accepted because ``run`` passes a step counter.

        Raises:
            ValueError: If ``curr`` has no neighbours.
        """
        adjs = env.get_adjs(curr)
        if not adjs:
            raise ValueError(f"no neighbours to move to from {curr}")

        scored = []
        for adj in adjs:
            step_cost = env.get_cost(curr, adj)
            estimate = self.h[adj.x, adj.y]
            scored.append(((step_cost + estimate) * self.weight, adj))

        # Keep every neighbour sharing the minimal score, in neighbour order.
        lowest = min(score for score, _ in scored)
        best_f, best_adj = random.choice(
            [entry for entry in scored if entry[0] == lowest])

        if best_f > self.h[curr.x, curr.y]:
            self.h_updates[curr.x, curr.y] += best_f - self.h[curr.x, curr.y]
            self.h[curr.x, curr.y] = best_f
        return best_adj

In [None]:
class Path(object):
    """A path through a world, optionally annotated with subgoals.

    Attributes:
        world: World the path lives in (its ``bitmap`` is drawn by ``vis``).
        start: First position of the path.
        goal: Target position of the path.
        poses: Sequence of positions along the path.
        hcses: Subgoal positions, or ``None`` until they have been computed
            (``add_hcses_to_paths`` assigns this attribute).
    """

    def __init__(self, world, start, goal, poses, hcses=None):
        self.world = world
        self.start = start
        self.goal = goal
        self.poses = poses
        self.hcses = hcses

    def vis(self):
        """Build a ``Vis`` showing the map, the start/goal, the path and subgoals.

        The subgoal layer is only added when ``hcses`` has been set;
        previously a path without subgoals made ``SubgoalLayer`` iterate
        ``None`` and fail.
        """
        layers = [
            BitmapLayer(self.world.bitmap),
            SceneLayer(self.start, self.goal),
            PathLayer(self.poses),
        ]
        if self.hcses is not None:
            layers.append(SubgoalLayer(self.hcses))
        return Vis(layers)

In [None]:
def run(world, scene, algo):
    """Drive ``algo`` from ``scene.start`` until it reaches ``scene.goal``.

    The algorithm's heuristic is initialised first. Each step asks the
    algorithm for the next position (passing the step counter as the fourth
    argument), asserts it is adjacent, and accumulates the move cost.

    Returns:
        A ``Replay`` built from the world, the scene, the visited positions,
        ``algo.h_updates`` and the total cost.
    """
    algo.init_h(world, scene.goal)
    position = scene.start
    visited = [position]
    cost_so_far = 0
    for step in itertools.count():
        if position == scene.goal:
            break
        successor = algo.get_next_pos(position, scene.goal, world, step)
        assert world.is_adj(position, successor)
        cost_so_far += world.get_cost(position, successor)
        visited.append(successor)
        position = successor
    return Replay(world, scene, visited, algo.h_updates, cost_so_far)

In [None]:
class VisLayer(object):
    """Base class for drawable layers; provides grid-to-pixel helpers.

    ``Vis`` reads ``shape`` from its first layer, so a layer used in that
    position has to expose a ``shape`` attribute or property.
    """

    @classmethod
    def create_empty_image(cls, shape, scale):
        """Return a blank RGBA image of ``shape[0] * scale`` by ``shape[1] * scale`` pixels."""
        pixel_size = (scale * shape[0], scale * shape[1])
        return Image.new('RGBA', pixel_size)

    def _get_point(self, x, y, scale):
        """Pixel coordinates of the top-left corner of cell ``(x, y)``."""
        px = x * scale
        py = y * scale
        return (px, py)

    def _get_mid_point(self, x, y, scale):
        """Pixel coordinates of the centre of cell ``(x, y)``."""
        half = scale / 2
        return (x * scale + half, y * scale + half)

    def _get_rect(self, x, y, scale):
        """Two corner points (top-left, bottom-right) covering cell ``(x, y)``."""
        top_left = self._get_point(x, y, scale)
        bottom_right = self._get_point(x + 1, y + 1, scale)
        return [top_left, bottom_right]


class BitmapLayer(VisLayer):
    """Layer that paints the map itself: walls grey, roads white."""

    def __init__(self, bitmap):
        self.bitmap = bitmap

    @property
    def shape(self):
        """Shape of the underlying bitmap."""
        return self.bitmap.shape

    def get_image(self, shape, scale):
        """Render every cell as a filled square.

        A grey one-pixel outline is added when ``scale`` is at least 8.

        Raises:
            ValueError: If a cell is neither ``World.WALL`` nor ``World.ROAD``.
        """
        assert shape == self.shape
        image = self.create_empty_image(self.shape, scale)
        painter = ImageDraw.Draw(image)
        # Cell borders are only drawn once cells are large enough.
        border = {'outline': 'grey', 'width': 1} if scale >= 8 else {}
        n_x, n_y = self.shape
        for x, y in itertools.product(range(n_x), range(n_y)):
            cell = self.bitmap[x, y]
            if cell == World.WALL:
                colour = 'grey'
            elif cell == World.ROAD:
                colour = 'white'
            else:
                raise ValueError
            painter.rectangle(self._get_rect(x, y, scale), fill=colour, **border)
        return image


class SceneLayer(VisLayer):
    """Layer marking the start cell in green and the goal cell in red."""

    def __init__(self, start, goal):
        self.start = start
        self.goal = goal

    def get_image(self, shape, scale):
        """Render the two markers; the goal is drawn last, on top of the start.

        ``start`` and ``goal`` are unpacked as ``(x, y)`` pairs.
        """
        image = self.create_empty_image(shape, scale)
        painter = ImageDraw.Draw(image)
        for marker, colour in ((self.start, 'green'), (self.goal, 'red')):
            painter.rectangle(self._get_rect(*marker, scale), fill=colour)
        return image


class PathLayer(VisLayer):
    """Layer drawing a path as red line segments between cell centres."""

    def __init__(self, path):
        self.path = path

    def get_image(self, shape, scale):
        """Render one line per consecutive pair of positions in ``path``.

        The line width is ``int(scale / 16)``.
        """
        image = self.create_empty_image(shape, scale)
        painter = ImageDraw.Draw(image)
        # Compute all segments before drawing any of them.
        segments = [
            (self._get_mid_point(src.x, src.y, scale),
             self._get_mid_point(dst.x, dst.y, scale))
            for src, dst in zip(self.path[:-1], self.path[1:])
        ]
        stroke = int(scale / 16)
        for segment in segments:
            painter.line(segment, fill='red', width=stroke)
        return image


class ColorOverlayLayer(VisLayer):
    """Single-colour overlay whose per-cell opacity comes from ``values``."""

    def __init__(self, values, color=(255, 0, 0)):
        self.values = values
        self.color = color

    @property
    def shape(self):
        """Shape of the ``values`` array."""
        return self.values.shape

    def get_image(self, shape, scale):
        """Fill each cell with ``color`` at alpha ``int(values[x, y])`` clamped to 0..255."""
        assert shape == self.shape
        image = self.create_empty_image(shape, scale)
        painter = ImageDraw.Draw(image)
        n_x, n_y = shape
        for x, y in itertools.product(range(n_x), range(n_y)):
            alpha = int(self.values[x, y])
            alpha = 0 if alpha < 0 else min(alpha, 255)
            painter.rectangle(self._get_rect(x, y, scale),
                              fill=(*self.color, alpha))
        return image
    
class MultiColorOverlayLayer(VisLayer):
    """Overlay with a per-cell colour taken from ``values`` and a shared alpha."""

    def __init__(self, values, alpha):
        self.values = values
        self.alpha = alpha

    @property
    def shape(self):
        """Full shape of ``values``, including the trailing colour axis."""
        return self.values.shape

    def get_image(self, shape, scale):
        """Fill each cell with ``values[x, y]`` (cast to ints) plus ``alpha``.

        ``shape`` must equal the shape of ``values`` without its last axis.
        """
        assert shape == self.shape[:-1]
        image = self.create_empty_image(shape, scale)
        painter = ImageDraw.Draw(image)
        n_x, n_y = shape
        for x, y in itertools.product(range(n_x), range(n_y)):
            channels = tuple(int(component) for component in self.values[x, y])
            painter.rectangle(self._get_rect(x, y, scale),
                              fill=channels + (self.alpha,))
        return image
    

class SubgoalLayer(VisLayer):
    """Overlay highlighting subgoal cells; repeated subgoals appear more opaque."""

    def __init__(self, subgoals):
        self.subgoals = subgoals

    def get_image(self, shape, scale):
        """Count subgoals per cell, scale the counts to 0..255 and render them.

        Rendering is delegated to a ``ColorOverlayLayer`` with colour
        ``(155, 0, 255)``.
        """
        counts = np.zeros(shape)
        for position in self.subgoals:
            counts[position.x, position.y] += 1
        alpha = normalize_alpha(counts, upper=255)
        overlay = ColorOverlayLayer(alpha, color=(155, 0, 255))
        return overlay.get_image(shape, scale)


class NoteLayer(VisLayer):
    """Layer writing blue text labels into cells.

    ``note`` is either a dict mapping positions (with ``.x`` / ``.y``) to
    values, or a numpy array whose entries are written into the matching
    cells. Any other type produces an empty image.
    """

    def __init__(self, note, font_scale=1):
        self.note = note
        self.font_scale = font_scale

    def get_image(self, shape, scale):
        """Render the labels using the ``DroidSans.ttf`` font.

        The font size is ``int(scale / 4) * font_scale``.
        """
        image = self.create_empty_image(shape, scale)
        painter = ImageDraw.Draw(image)
        font = ImageFont.truetype('DroidSans.ttf', size=int(scale / 4) * self.font_scale)

        def write(x, y, label):
            # Anchor the text at the cell's top-left corner.
            painter.text(self._get_point(x, y, scale),
                         text=label, fill='blue', font=font)

        if isinstance(self.note, dict):
            for pos, value in self.note.items():
                label = str(value)
                write(pos.x, pos.y, label)
        elif isinstance(self.note, np.ndarray):
            for x in range(self.note.shape[0]):
                for y in range(self.note.shape[1]):
                    write(x, y, str(self.note[x, y]))
        return image


class Vis(object):
    """Stack of layers composited into a single image.

    Attributes:
        layers: Objects exposing ``get_image(shape, scale)``.
        scale: Pixels per grid cell.
        shape: Grid shape, taken from ``layers[0].shape``.
    """

    def __init__(self, layers, scale=8):
        self.layers = layers
        self.scale = scale
        self.shape = layers[0].shape

    def get_image(self):
        """Alpha-composite all layers, in order, onto a blank canvas."""
        canvas = VisLayer.create_empty_image(self.shape, self.scale)

        for layer in self.layers:
            image = layer.get_image(self.shape, self.scale)
            canvas = Image.alpha_composite(canvas, image)

        return canvas

    def save(self, open_=False, path=None):
        """Save the rendered image and return the path it was written to.

        Args:
            open_: When true, open the saved file via ``Vis.open``.
            path: Destination; defaults to a fresh UUID-named PNG under
                ``temp_images/``.

        The destination's parent directory is created if missing; before,
        saving failed whenever ``temp_images/`` did not already exist.
        """
        if not path:
            path = pathlib.Path('temp_images/' + str(uuid.uuid4()) + '.png')
        # Only filesystem paths have a directory to create; anything else
        # (e.g. a file object) is handed to the image's save() unchanged.
        if isinstance(path, (str, os.PathLike)):
            pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)
        self.get_image().save(path)
        if open_:
            self.open(path)
        return path

    @staticmethod
    def open(path):
        """Run the ``open`` command on ``path``; raises if it exits non-zero.

        NOTE(review): ``open`` is the macOS launcher; presumably this is only
        used on macOS -- confirm before relying on it elsewhere.
        """
        subprocess.run(['open', str(path)], check=True)

    def display(self):
        """Show the rendered image inline via IPython's display."""
        display.display(self.get_image())


def normalize_alpha(x, upper=127):
    """Rescale an array so its minimum maps to 0 and its maximum to about ``upper``.

    The input is copied to float64 and left unmodified. Machine epsilon is
    added to the divisor, so an all-equal array maps to zeros instead of
    dividing by zero.
    """
    shifted = x.astype(np.float64)
    shifted = shifted - np.min(shifted)
    divisor = np.max(shifted) + np.finfo(np.float64).eps
    return (shifted / divisor) * upper

In [None]:
@dataclass(order=True)
class PrioritizedItem:
    """Priority-queue entry ordered by ``(priority, g, h)``.

    ``pos`` is excluded from comparisons, so positions themselves never
    need to be orderable.
    """
    # Primary sort key.
    priority: int
    # Tie-breakers, compared in this order after ``priority``.
    g: int
    h: int
    # Payload; ignored by ==, <, etc.
    pos: Any = field(compare=False)

In [None]:
def reconstruct(goal, came_from):
    """Return the chain of predecessors leading to ``goal``, oldest first.

    ``goal`` itself is not included; an empty list is returned when ``goal``
    has no entry in ``came_from``.

    Walks the chain iteratively. The earlier recursive version exceeded
    Python's recursion limit on long chains and rebuilt the list at every
    level.

    Args:
        goal: Node to trace back from.
        came_from: Mapping from a node to its predecessor.
    """
    chain = []
    node = goal
    while node in came_from:
        node = came_from[node]
        chain.append(node)
    # Collected from the goal backwards; callers expect start-to-goal order.
    chain.reverse()
    return chain

In [None]:
def get_opt_path(world, start, goal):
    """Best-first search from ``start`` to ``goal`` ordered by ``g + Pos.heuristic``.

    Args:
        world: Provides ``get_adjs(pos)`` and ``get_cost(a, b)``.
        start: Position to search from.
        goal: Position to reach.

    Returns:
        A ``Path`` whose ``poses`` run from ``start`` to ``goal``, or
        ``False`` when the frontier empties without reaching ``goal``.
    """
    frontier = queue.PriorityQueue()
    best_g = {start: 0}  # cheapest known cost to reach each position
    came_from = {}
    h = Pos.heuristic(start, goal)
    frontier.put(PrioritizedItem(h, 0, h, start))

    while not frontier.empty():
        curr = frontier.get()
        if curr.pos == goal:
            poses = reconstruct(goal, came_from) + [goal]
            return Path(world, start, goal, poses)
        # A cheaper entry for this position was queued later and, having a
        # lower priority value, has already been expanded; re-expanding this
        # stale entry cannot improve any neighbour, so skip it.
        if curr.g > best_g[curr.pos]:
            continue
        for adj in world.get_adjs(curr.pos):
            g = curr.g + world.get_cost(curr.pos, adj)
            if adj not in best_g or g < best_g[adj]:
                came_from[adj] = curr.pos
                best_g[adj] = g
                h = Pos.heuristic(adj, goal)
                frontier.put(PrioritizedItem(g + h, g, h, adj))
    return False

In [None]:
def sample_paths(world, limit=16, use_edges=False, verbose=False):
    """Sample random start/goal pairs on road tiles and compute their paths.

    Args:
        world: Iterable of positions providing ``get`` and ``get_all_adjs``.
        limit: Number of start/goal pairs to draw.
        use_edges: When true, only road tiles with at least one wall among
            ``get_all_adjs`` are eligible as endpoints.
        verbose: Show a progress bar.

    Returns:
        List of paths from ``get_opt_path``. Pairs for which no path exists
        (``get_opt_path`` returned ``False``) are skipped, so the list can be
        shorter than ``limit``. Previously those ``False`` values were kept
        and broke later steps that expect path objects.
    """
    tiles = [t for t in world if world.get(t) == World.ROAD]
    if use_edges:
        tiles = [
            t for t in tiles
            if any(world.get(adj) == World.WALL for adj in world.get_all_adjs(t))
        ]
    opt_paths = []
    for _ in tqdm_nb(range(limit), desc="Building Optimal Paths", disable=not verbose):
        start = random.choice(tiles)
        # The goal must differ from the start.
        goal = random.choice([t for t in tiles if t != start])
        path = get_opt_path(world, start, goal)
        if path is not False:
            opt_paths.append(path)
    return opt_paths

In [None]:
def hill_climbable(world, start, goal):
    """Return whether greedy descent from ``start`` reaches ``goal``.

    At each step the walk moves to the neighbour minimising
    ``cost(curr, adj) + Pos.heuristic(adj, goal)`` (first one on ties). It
    fails as soon as a position repeats or there is no neighbour to move to.

    ``start`` is treated as visited from the beginning, matching
    ``hill_climb``. The walk is deterministic, so this only detects a loop
    through ``start`` one lap earlier and never changes the result.
    """
    visited = {start}
    curr = start
    while True:
        if curr == goal:
            return True
        adjs = world.get_adjs(curr)
        if len(adjs) == 0:
            # Dead end: argmin over an empty list would raise instead.
            return False
        vals = [world.get_cost(curr, adj) + Pos.heuristic(adj, goal)
                for adj in adjs]
        curr = adjs[np.array(vals).argmin()]
        if curr in visited:
            return False
        visited.add(curr)

In [None]:
# Sentinel position. convert_paths_to_points(convert_goal=True) substitutes it
# for a point's `hcs` whenever that `hcs` equals the point's goal.
GOAL_HCS = Pos(-1, -1)

In [None]:
def get_earliest_hcs_linear(path):
    """Linear scan for the last pose that is hill-climbable from ``path.start``.

    Poses are tested from the end of ``path.poses`` backwards and the first
    hit is returned.

    Raises:
        ValueError: If no pose is hill-climbable (including an empty path).
    """
    reachable = (
        candidate for candidate in reversed(path.poses)
        if hill_climbable(path.world, path.start, candidate)
    )
    for hit in reachable:
        return hit
    raise ValueError
    
def get_earliest_hcs(path):
    """Bisect ``path.poses`` for a pose hill-climbable from ``path.start``.

    The search keeps a ``lower`` index that tested climbable (index 0 is
    taken as climbable without testing) and an ``upper`` index that did
    not, and returns ``poses[lower]`` once the two are adjacent.

    NOTE(review): bisection presumes that once a pose is not hill-climbable
    no later pose is either -- confirm this holds for the paths used.

    The guards were off by one: an empty path returned ``poses[0]`` and a
    single-pose path returned ``poses[1]``, both of which raise IndexError.

    Raises:
        ValueError: If the path has no poses (same exception type as
            ``get_earliest_hcs_linear``).
    """
    poses = path.poses
    if len(poses) == 0:
        raise ValueError("cannot pick a subgoal from an empty path")
    if len(poses) == 1:
        return poses[0]
    lower = 0
    upper = len(poses)
    while upper - lower > 1:
        idx = (lower + upper) // 2
        if hill_climbable(path.world, path.start, poses[idx]):
            lower = idx
        else:
            upper = idx
    return poses[lower]

In [None]:
def get_all_hcs(path):
    """Collect the successive subgoals along ``path`` up to, not including, its goal.

    Repeatedly asks ``get_earliest_hcs`` for a subgoal, records it, and
    restarts from that subgoal with the remaining poses until the subgoal
    returned is the goal itself.

    Raises:
        ValueError: If the subgoal returned is the current start at index 0.
            The shortened path would then be identical to the current one
            and the loop would never terminate.
    """
    hcses = []
    while True:
        hcs = get_earliest_hcs(path)
        if hcs == path.goal:
            return hcses
        cut = path.poses.index(hcs)
        if cut == 0 and hcs == path.start:
            raise ValueError(f"no hill-climbing progress possible from {hcs}")
        hcses.append(hcs)
        path = Path(world=path.world, start=hcs, goal=path.goal,
                    poses=path.poses[cut:])

In [None]:
@dataclasses.dataclass
class HCSPoint:
    """One training example: a start/goal pair and the subgoal to predict."""
    start: Pos
    goal: Pos
    hcs: Pos

    def xy(self):
        """Return ``(x, y)`` index arrays: ``[start, goal]`` and ``[hcs]``.

        Positions are translated through the module-level ``pos_to_idx``
        mapping, which has to be defined before this is called.
        """
        features = np.array([pos_to_idx[pos] for pos in (self.start, self.goal)])
        target = np.array([pos_to_idx[self.hcs]])
        return features, target

    def vis(self, world):
        """Build a ``Vis`` showing the map, the start/goal and the subgoal."""
        layers = [
            BitmapLayer(world.bitmap),
            SceneLayer(self.start, self.goal),
            SubgoalLayer([self.hcs]),
        ]
        return Vis(layers)

In [None]:
def add_hcses_to_paths(paths, verbose=False):
    """Compute the subgoals of every path and store them on ``path.hcses`` in place."""
    progress = tqdm_nb(paths, desc="Building HCS", disable=not verbose)
    for current in progress:
        current.hcses = get_all_hcs(current)

def convert_paths_to_points(paths, combination=False, convert_goal=False, verbose=False):
    """Turn annotated paths into ``HCSPoint`` training examples.

    Args:
        paths: Paths whose ``hcses`` have already been computed.
        combination: When false, emit one point per pose, paired with the
            path's goal and the next subgoal after that pose (or the goal
            once the subgoals are exhausted). When true, emit a point for
            every (start pose, goal pose) pair, skipping the last pose as a
            start and the first pose as a goal.
        convert_goal: Replace ``hcs`` by ``GOAL_HCS`` on points whose
            subgoal equals their goal.
        verbose: Show a progress bar.

    Returns:
        List of ``HCSPoint``.
    """
    points = []
    if combination:
        for path in tqdm_nb(paths, desc="Building Data Points", disable=not verbose):
            hcs_indices = [path.poses.index(hcs) for hcs in path.hcses]
            # Fixed: this used to read `path.pose.path`, an AttributeError.
            last_idx = len(path.poses) - 1
            for i, start in enumerate(path.poses):
                if i == last_idx:
                    continue
                for j, goal in enumerate(path.poses):
                    if j == 0:
                        continue
                    # NOTE(review): pairs with j <= i (goal at or before the
                    # start) are kept, and the chosen subgoal may lie beyond
                    # `goal` on the path -- confirm both are intended.
                    hcs = next(
                        (path.poses[hcs_idx]
                         for hcs_idx in hcs_indices if hcs_idx > i),
                        goal
                    )
                    points.append(HCSPoint(start, goal, hcs))
    else:
        for path in tqdm_nb(paths, desc="Building Data Points", disable=not verbose):
            hcs_idx = 0
            for start in path.poses:
                # Standing on the current subgoal: target the next one.
                if hcs_idx < len(path.hcses) and start == path.hcses[hcs_idx]:
                    hcs_idx += 1
                if hcs_idx < len(path.hcses):
                    hcs = path.hcses[hcs_idx]
                else:
                    hcs = path.goal
                points.append(HCSPoint(start, path.goal, hcs))
    if convert_goal:
        for point in points:
            if point.hcs == point.goal:
                point.hcs = GOAL_HCS
    return points

In [None]:
class StreamingPointsDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that generates fresh training points on every pass.

    Each iteration samples ``size`` paths from the module-level ``world``
    (which must be defined before iterating), derives their data points and
    yields a random ``sample_rate`` fraction of them.
    """

    def __init__(self, size=32, sample_rate=0.1):
        self.sample_rate = sample_rate
        self.size = size

    def __iter__(self):
        """Yield ``(x, y)`` pairs of ``LongTensor`` built from ``HCSPoint.xy()``."""
        fresh_paths = sample_paths(world, limit=self.size, verbose=False)
        add_hcses_to_paths(fresh_paths)
        candidates = convert_paths_to_points(fresh_paths)
        keep = int(self.sample_rate * len(candidates))
        for point in random.sample(candidates, keep):
            features, target = point.xy()
            yield (torch.LongTensor(features), torch.LongTensor(target))

In [None]:
class PointDataset(torch.utils.data.Dataset):
    """Map-style dataset holding pre-converted ``(x, y)`` tensor pairs.

    Every point's ``xy()`` result is converted to ``LongTensor`` once, at
    construction time, and kept in ``self.xy``.
    """

    def __init__(self, points):
        def as_tensors(point):
            features, target = point.xy()
            return (torch.LongTensor(features), torch.LongTensor(target))

        self.xy = [as_tensors(point) for point in points]

    def __getitem__(self, key):
        """Return the stored ``(x, y)`` pair at ``key``."""
        return self.xy[key]

    def __len__(self):
        """Number of stored pairs."""
        return len(self.xy)

In [None]:
def plot_paths(world, paths):
    """Display two heat maps over the world bitmap.

    The first shows how often each cell is visited by ``paths``. The second
    shows how often each cell is a subgoal, on a ``log1p`` scale. Both are
    rescaled to 0..255 and drawn in magenta.
    """
    # NOTE(review): the counters inherit world.bitmap's dtype via zeros_like;
    # a narrow integer dtype could overflow on busy cells -- confirm the
    # bitmap dtype.
    visit_counts = np.zeros_like(world.bitmap)
    subgoal_counts = np.zeros_like(world.bitmap)
    for path in paths:
        for pos in path.poses:
            visit_counts[pos.x, pos.y] += 1
        for pos in path.hcses:
            subgoal_counts[pos.x, pos.y] += 1
    visit_alpha = normalize_alpha(visit_counts.astype(np.float64), upper=255)
    subgoal_alpha = normalize_alpha(
        np.log1p(subgoal_counts.astype(np.float64)), upper=255)
    for overlay in (visit_alpha, subgoal_alpha):
        Vis([
            BitmapLayer(world.bitmap),
            ColorOverlayLayer(overlay, color=(255, 0, 255)),
        ]).display()

In [None]:
def hill_climb(world, start, subgoal, goal, num_steps=None):
    """Greedily walk from ``start`` towards ``subgoal`` and return the poses visited.

    At each step the walk moves to the neighbour minimising
    ``cost(curr, adj) + Pos.heuristic(adj, subgoal)`` (first one on ties).
    It stops when it reaches ``goal`` or ``subgoal``, when the next move
    would revisit a position, when there is no neighbour to move to, or
    once ``num_steps`` moves have been made (no limit when ``None``).

    Returns:
        List of positions starting with ``start``.
    """
    visited = {start}
    poses = [start]
    curr = start
    while num_steps is None or len(poses) < num_steps + 1:
        if curr == goal or curr == subgoal:
            return poses
        adjs = world.get_adjs(curr)
        if len(adjs) == 0:
            # Dead end: argmin over an empty list would raise instead.
            return poses
        vals = [world.get_cost(curr, adj) + Pos.heuristic(adj, subgoal)
                for adj in adjs]
        curr = adjs[np.array(vals).argmin()]
        if curr in visited:
            return poses
        visited.add(curr)
        poses.append(curr)
    return poses

In [None]:
# selected_worlds = get_selected()
# world = selected_worlds['den020d.map']

# paths = sample_paths(world, limit=8)
# add_hcses_to_paths(paths)
# points = convert_paths_to_points(paths)

# for point in random.sample(points, 20):
#     point.vis(world).display()

In [None]:
def fix_subgoal(env, subgoal):
    """Return ``subgoal`` if it is a road tile, else a nearby road tile.

    The search expands outwards through ``env.get_all_adjs`` breadth-first,
    so the tile returned is one reachable in the fewest expansion steps.
    The earlier version popped from the end of its list, which made the
    search depth-first and could return a road tile far from ``subgoal``.

    NOTE(review): this compares against ``Map.ROAD`` while the rest of the
    notebook uses ``World.ROAD`` -- confirm ``Map`` is the intended name.

    Raises:
        ValueError: If no road tile is reachable from ``subgoal``.
    """
    if env.get(subgoal) == Map.ROAD:
        return subgoal

    frontier = collections.deque([subgoal])
    seen = {subgoal}
    while frontier:
        target = frontier.popleft()
        if env.get(target) == Map.ROAD:
            return target
        for adj in env.get_all_adjs(target):
            if adj not in seen:
                seen.add(adj)
                frontier.append(adj)
    raise ValueError(f"no road tile reachable from {subgoal}")

def get_subgoal(model, start, goal):
    """Return the position the model scores highest for ``(start, goal)``.

    Positions are mapped to indices with the module-level ``pos_to_idx`` and
    the winning index is mapped back with ``idx_to_pos``. The model is
    switched to eval mode and left in it.

    The input is moved to the device of the model's parameters and the
    output is brought back to the CPU before conversion to numpy. Without
    that, a model on the GPU failed on the CPU input and on ``.numpy()``.
    """
    model.eval()
    with torch.no_grad():
        x = torch.LongTensor([pos_to_idx[start], pos_to_idx[goal]])
        x = x.reshape(1, *x.shape)  # add the batch dimension
        first_param = next(model.parameters(), None)
        if first_param is not None:
            x = x.to(first_param.device)
        out = model(x)
        return idx_to_pos[out.cpu().numpy().argmax(-1)[0]]

In [None]:
def get_gif(images, fps=30, show=True):
    """Write ``images`` as a GIF under ``temp_images/`` and optionally display it.

    Args:
        images: Frames passed to ``imageio.mimsave``.
        fps: Frame rate passed to ``imageio.mimsave``.
        show: When true, display the saved file inline.

    The ``temp_images`` directory is created when missing; previously the
    write failed unless the directory already existed.
    """
    os.makedirs('temp_images', exist_ok=True)
    fname = 'temp_images/' + str(uuid.uuid4()) + '.gif'
    imageio.mimsave(fname, images, fps=fps)
    if show:
        display.display(display.Image(filename=fname))

In [None]:
class Model(nn.Module):
    """Embedding followed by two hidden layers, scoring every vocabulary index.

    Input is a tensor of index pairs; the two embeddings are concatenated,
    passed through two Linear -> Dropout -> ReLU stages and projected to
    ``vocab_size`` log-probabilities.
    """

    def __init__(self, vocab_size, embed_dim=3, hidden_dim=32):
        super().__init__()
        self.embed_dim = embed_dim

        pair_dim = embed_dim * 2
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(pair_dim, hidden_dim)
        self.do1 = nn.Dropout()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.do2 = nn.Dropout()
        self.fc3 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map indices of shape (N, 2) to log-probabilities of shape (N, vocab_size)."""
        # (N, 2) -> (N, 2, embed_dim) -> (N, 2 * embed_dim)
        pair = self.embed(x).reshape(-1, self.embed_dim * 2)
        hidden = F.relu(self.do1(self.fc1(pair)))
        hidden = F.relu(self.do2(self.fc2(hidden)))
        return F.log_softmax(self.fc3(hidden), dim=1)

In [None]:
class LargeEmbedModel(nn.Module):
    """Embedding followed by one wide hidden layer, scoring every vocabulary index.

    The two embeddings of an index pair are concatenated, passed through a
    single Linear -> Dropout -> ReLU stage and projected to ``vocab_size``
    log-probabilities.
    """

    def __init__(self, vocab_size, embed_dim=32, hidden_dim=512):
        super().__init__()
        self.embed_dim = embed_dim

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(2 * embed_dim, hidden_dim)
        self.do1 = nn.Dropout()
        self.fc3 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map indices of shape (N, 2) to log-probabilities of shape (N, vocab_size)."""
        # (N, 2) -> (N, 2, embed_dim) -> (N, 2 * embed_dim)
        pair = self.embed(x).reshape(-1, self.embed_dim * 2)
        hidden = F.relu(self.do1(self.fc1(pair)))
        return F.log_softmax(self.fc3(hidden), dim=1)

In [None]:
class SingleLayerEmbedModel(nn.Module):
    """Embedding followed directly by a linear projection to the vocabulary.

    ``hidden_dim`` is accepted for signature parity with the other models
    but is not used.
    """

    def __init__(self, vocab_size, embed_dim=32, hidden_dim=512):
        super().__init__()
        self.embed_dim = embed_dim

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(2 * embed_dim, vocab_size)

    def forward(self, x):
        """Map indices of shape (N, 2) to log-probabilities of shape (N, vocab_size)."""
        # (N, 2) -> (N, 2, embed_dim) -> (N, 2 * embed_dim)
        pair = self.embed(x).reshape(-1, self.embed_dim * 2)
        return F.log_softmax(self.fc(pair), dim=1)

In [None]:
class ResModel(nn.Module):
    """Embedding followed by a stack of 16 equal-width linear layers.

    The stack is held in an ``nn.ModuleList``. It used to be a plain Python
    list, which left the 16 layers unregistered: they were missing from
    ``parameters()`` and ``state_dict()``, so they were never optimised,
    saved, or moved by ``.to(device)``.

    NOTE(review): despite the name there are no skip connections and no
    non-linearity between the layers, so the stack composes to a single
    linear map -- confirm whether ``x = x + F.relu(fc(x))`` or similar was
    intended.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim=32,
    ):
        super().__init__()
        self.embed_dim = embed_dim

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fcs = nn.ModuleList(
            nn.Linear(embed_dim * 2, embed_dim * 2) for _ in range(16)
        )
        self.out = nn.Linear(embed_dim * 2, vocab_size)

    def forward(self, x):
        """Map indices of shape (N, 2) to log-probabilities of shape (N, vocab_size)."""
        # (N, 2) -> (N, 2, embed_dim) -> (N, 2 * embed_dim)
        x = self.embed(x).reshape(-1, self.embed_dim * 2)
        for fc in self.fcs:
            x = fc(x)
        return F.log_softmax(self.out(x), dim=1)

In [None]:
def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    """Print the layer-by-layer report from ``summary_string``.

    All arguments are forwarded unchanged to ``summary_string``.

    Returns:
        The ``(total_params, trainable_params)`` pair from ``summary_string``.
    """
    report, params_info = summary_string(
        model, input_size, batch_size, device, dtypes)
    print(report)
    return params_info


def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    """Build a text table of a model's layers, output shapes and parameter counts.

    Forward hooks are registered on every sub-module except ``nn.Sequential``
    and ``nn.ModuleList`` containers, a random batch of two samples is run
    through the model, and the shapes the hooks recorded are tabulated.

    The hooks are removed in a ``finally`` block. Previously a failing
    forward pass left them attached to the model.

    Args:
        model: Module to summarise.
        input_size: Shape of one input without the batch dimension, or a
            list of such shapes for models taking several inputs.
        batch_size: Value shown in place of the batch dimension.
        device: Device the random inputs are created on. The default is
            ``cuda:0``, so pass a CPU device when the model lives on the CPU.
        dtypes: Tensor types for the inputs; defaults to
            ``torch.FloatTensor`` for each.

    Returns:
        ``(summary_str, (total_params, trainable_params))``.
    """
    if dtypes is None:
        dtypes = [torch.FloatTensor]*len(input_size)

    summary_str = ''

    def register_hook(module):
        def hook(module, input, output):
            # Record shapes and parameter count for one executed module.
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
        ):
            hooks.append(module.register_forward_hook(hook))

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype).to(device=device)
         for in_size, dtype in zip(input_size, dtypes)]

    # create properties
    summary = OrderedDict()
    hooks = []

    try:
        # register hook
        model.apply(register_hook)

        # make a forward pass
        model(*x)
    finally:
        # remove these hooks, even when the forward pass raised
        for h in hooks:
            h.remove()

    summary_str += "----------------------------------------------------------------" + "\n"
    line_new = "{:>20}  {:>25} {:>15}".format(
        "Layer (type)", "Output Shape", "Param #")
    summary_str += line_new + "\n"
    summary_str += "================================================================" + "\n"
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]

        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        summary_str += line_new + "\n"

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(sum(input_size, ()))
                           * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. /
                            (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    summary_str += "================================================================" + "\n"
    summary_str += "Total params: {0:,}".format(total_params) + "\n"
    summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
    summary_str += "Non-trainable params: {0:,}".format(total_params -
                                                        trainable_params) + "\n"
    summary_str += "----------------------------------------------------------------" + "\n"
    summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
    summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
    summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
    summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
    summary_str += "----------------------------------------------------------------" + "\n"
    return summary_str, (total_params, trainable_params)