In [2]:
import torch

In [8]:
state_dict = torch.load('trained_models/trained_model.pt')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, val in state_dict.items():
    print(k)
    for v in val:
        print(v)
        print('-'*100)
    print('-'*100)
#     name = k[7:] # remove `module.`
#     new_state_dict[name] = v

state_dict
lstm1_layer.attention.query_projection.weight
----------------------------------------------------------------------------------------------------
lstm1_layer.attention.query_projection.bias
----------------------------------------------------------------------------------------------------
lstm1_layer.attention.key_projection.weight
----------------------------------------------------------------------------------------------------
lstm1_layer.attention.key_projection.bias
----------------------------------------------------------------------------------------------------
lstm1_layer.attention.value_projection.weight
----------------------------------------------------------------------------------------------------
lstm1_layer.attention.value_projection.bias
----------------------------------------------------------------------------------------------------
lstm1_layer.attention.out_projection.weight
-------------------------------------------------------------------------

In [None]:
import torch
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from torch.nn.utils import clip_grad_norm_
from .dataset import IAMDataset
from .model import HandwritingGenerator
from .loss import HandwritingLoss
from copy import deepcopy
from .utils import plotstrokes
from typing import Union
from pathlib import Path

In [None]:
class Trainer:
    def __init__(self, parameters):

        self.params = parameters

        # Initialize datasets
        self.trainset = IAMDataset(self.params)

        self.alphabet = self.trainset.alphabet
        alphabet_size = len(self.alphabet)

        # Initialize loaders
        self.trainloader = DataLoader(
            self.trainset,
            batch_size=self.params.batch_size,
            shuffle=False,
            num_workers=self.params.num_workers,
            sampler=RandomSampler(self.trainset),
        )

        # Checking for GPU
        self.use_gpu = self.params.use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda:0" if self.use_gpu else "cpu")

        try:
            path = self.params.model_dir / f"trained_model_{self.params.idx}.pt"
            print(f"Loading model : {path}")
            self.model = self.load_model(path)
            print(f"'{path}' model loaded!")
        except Exception as e:
            print(f"{path} model not found.")
            print("Creating new model.")
            self.model = HandwritingGenerator(
                alphabet_size=alphabet_size,
                hidden_size=self.params.hidden_size,
                num_window_components=self.params.num_window_components,
                num_mixture_components=self.params.num_mixture_components,
            )

        self.model.to(self.device)

        print(self.model)

        print("Number of parameters = {}".format(self.model.num_parameters()))

        # Optimizer setup
        self.optimizer = self.optimizer_select()

        # Criterion
        self.criterion = HandwritingLoss(self.params)

    def train_model(self):
        min_loss = None
        best_model = self.model.state_dict()
        avg_losses = np.zeros(self.params.num_epochs)
        self.params.model_dir.mkdir(parents=True, exist_ok=True)
        path = self.params.model_dir / f"trained_model_{self.params.idx}.pt"
        for epoch in range(self.params.num_epochs):
            try:
                print("Epoch {}".format(epoch + 1))

                # Set mode to training
                self.model.train()

                # Go through the training set
                avg_losses[epoch] = self.train_epoch()

                print("Average loss = {:.3f}".format(avg_losses[epoch]))

                if min_loss is None or min_loss >= avg_losses[epoch]:
                    min_loss = avg_losses[epoch]
                    best_model = self.model.state_dict()

                if (epoch + 1) % 5 == 0:
                    self.save_model(best_model, path)

            except KeyboardInterrupt:
                print("Training was interrupted")
                break
        # Saving trained model
        print("Saving model...")
        self.save_model(best_model, path)
        return avg_losses

    def train_epoch(self):
        losses = 0.0
        inf = float("inf")
        for batch_index, (data) in enumerate(self.trainloader, 1):
            if batch_index % 20 == 0:
                print("Step {}".format(batch_index))
                print("Average Loss so far: {}".format(losses / batch_index))
            # Split data tuple
            onehot, strokes = data
            # Plot strokes
            # plotstrokes(strokes)
            # Move inputs to correct device
            onehot, strokes = onehot.to(self.device), strokes.to(self.device)
            # Main Model Forward Step
            self.model.reset_state()
            loss = None
            for idx in range(strokes.size(1) - 1):
                output, _ = self.model(strokes[:, idx : idx + 1, :], onehot)
                # Loss Computation
                loss = (
                    self.criterion(output, strokes[:, idx + 1 : idx + 2, :])
                    / strokes.size(1)
                    if loss is None
                    else loss
                    + self.criterion(output, strokes[:, idx + 1 : idx + 2, :])
                    / strokes.size(1)
                )
            if loss.data.item() == inf or loss.data.item() == -inf:
                print("Warning, received inf loss. Skipping it")
            elif loss.data.item() != loss.data.item():
                print("Warning, received NaN loss.")
            else:
                losses = losses + loss.data.item()
            # Zero the optimizer gradient
            self.optimizer.zero_grad()
            # Backward step
            loss.backward()
            # Clip gradients
            clip_grad_norm_(self.model.parameters(), self.params.max_norm)
            # Weight Update
            self.optimizer.step()
            if self.use_gpu is True:
                torch.cuda.synchronize()
            del onehot, strokes, data
        # Compute the average loss for this epoch
        avg_loss = losses / len(self.trainloader)
        return avg_loss

    @staticmethod
    def load_model(path: Union[Path, str]):
        package = torch.load(path, map_location=lambda storage, loc: storage)
        parameters = package["parameters"]
        state_dict = package["state_dict"]
        return HandwritingGenerator.load_model(parameters, state_dict)

    def save_model(self, model_parameters, path):
        model = deepcopy(self.model)
        model.load_state_dict(model_parameters)
        torch.save(
            self.serialize(model), path,
        )

    def serialize(self, model):
        model_is_cuda = next(model.parameters()).is_cuda
        model = model.cpu() if model_is_cuda else self.model
        package = {
            "state_dict": model.state_dict(),
            "optim_dict": self.optimizer.state_dict(),
            "parameters": {
                "alphabet_size": self.model.alphabet_size,
                "hidden_size": self.model.hidden_size,
                "num_window_components": self.model.num_window_components,
                "num_mixture_components": self.model.num_mixture_components,
            },
        }
        return package

    def optimizer_select(self):
        if self.params.optimizer == "Adam":
            return optim.Adam(self.model.parameters(), lr=self.params.learning_rate)
        elif self.params.optimizer == "Adadelta":
            return optim.Adadelta(self.model.parameters(), lr=self.params.learning_rate)
        elif self.params.optimizer == "SGD":
            return optim.SGD(
                self.model.parameters(),
                lr=self.params.learning_rate,
                momentum=self.params.momentum,
                nesterov=self.params.nesterov,
            )
        elif self.params.optimizer == "RMSprop":
            return optim.RMSprop(
                self.model.parameters(),
                lr=self.params.learning_rate,
                momentum=self.params.momentum,
            )
        else:
            raise NotImplementedError


In [1]:
breakpoint()

--Call--
> [0;32m/home/iorua8/.local/lib/python3.7/site-packages/IPython/core/displayhook.py[0m(252)[0;36m__call__[0;34m()[0m
[0;32m    250 [0;31m        [0msys[0m[0;34m.[0m[0mstdout[0m[0;34m.[0m[0mflush[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    251 [0;31m[0;34m[0m[0m
[0m[0;32m--> 252 [0;31m    [0;32mdef[0m [0m__call__[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mresult[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    253 [0;31m        """Printing with history cache management.
[0m[0;32m    254 [0;31m[0;34m[0m[0m
[0m
ipdb> f
*** NameError: name 'f' is not defined
ipdb> 
*** NameError: name 'f' is not defined
ipdb> d
*** Newest frame
ipdb> d
*** Newest frame
ipdb> c
