# 1

# In [42]:
from torch.optim import SGD, Adam
import torch.nn.functional as F
import random
from tqdm import tqdm
import math
from sklearn.model_selection import train_test_split

def data_gen(X,y, batchsize):
    '''
    Generator for data.

    Yields consecutive (X_batch, y_batch) pairs in order, each holding at most
    `batchsize` items; the final pair holds the remainder. Batches are cut at
    positions of X, and y is sliced with the same indices (so y is expected
    to have the same length as X). No empty batch is ever produced, and an
    empty X yields nothing.

    Raises ValueError (on the first next()) if batchsize is not positive.
    '''
    if batchsize <= 0:
        raise ValueError(f"batchsize must be positive, got {batchsize}")
    # Stepping by batchsize covers full batches and the remainder in one loop,
    # without relying on a loop variable that may never have been bound.
    for start in range(0, len(X), batchsize):
        stop = start + batchsize
        yield X[start:stop], y[start:stop]
        

class Trainer():
    def __init__(self, model, optimizer_type, learning_rate, epoch, batch_size, input_transform=lambda x: x, loss_fn=None):
        """ The class for training the model
        model: nn.Module
            A pytorch model whose forward pass returns a 3-tuple
            (predictions, mu, logvar).
        optimizer_type: 'sgd', 'adam' or 'adam_l2'
            'sgd' uses momentum 0.9; 'adam_l2' is Adam with weight_decay=1e-5.
        learning_rate: float
        epoch: int
            Number of passes over the training data.
        batch_size: int
        input_transform: func
            Applied to the float input tensor before it reaches the model.
            Can do reshape here.
        loss_fn: callable, optional
            loss_fn(predictions, batch_input, batch_output, mu, logvar) -> scalar
            tensor. Alternatively, subclass and override _compute_loss.
            Training/evaluation raise NotImplementedError if neither is given.

        Raises ValueError if optimizer_type is not one of the names above.
        """
        self.model = model
        if optimizer_type == "sgd":
            self.optimizer = SGD(model.parameters(), learning_rate, momentum=0.9)
        elif optimizer_type == "adam":
            self.optimizer = Adam(model.parameters(), learning_rate)
        elif optimizer_type == 'adam_l2':
            self.optimizer = Adam(model.parameters(), learning_rate, weight_decay=1e-5)
        else:
            # Fail now rather than with an AttributeError on the first step.
            raise ValueError(
                f"Unknown optimizer_type {optimizer_type!r}; expected 'sgd', 'adam' or 'adam_l2'"
            )

        self.epoch = epoch
        self.batch_size = batch_size
        self.input_transform = input_transform
        self.loss_fn = loss_fn

    def _prepare(self, inputs, outputs):
        """Convert (inputs, outputs) to float / int64 tensors and apply input_transform once."""
        inputs = self.input_transform(torch.as_tensor(inputs, dtype=torch.float))
        outputs = torch.as_tensor(outputs, dtype=torch.int64)
        return inputs, outputs

    def _compute_loss(self, batch_input, batch_output):
        """Run the model on one batch and return its scalar loss tensor."""
        if self.loss_fn is None:
            raise NotImplementedError(
                "Trainer has no loss: pass loss_fn(predictions, batch_input, batch_output, mu, logvar) "
                "to the constructor or override Trainer._compute_loss"
            )
        batch_predictions, mu, logvar = self.model(batch_input)
        return self.loss_fn(batch_predictions, batch_input, batch_output, mu, logvar)

    def _evaluate_prepared(self, inputs, outputs):
        """Return the summed batch loss divided by the sample count, for already-prepared tensors.

        Puts the model in eval mode and computes no gradients.
        """
        n_samples = len(outputs)
        if n_samples == 0:
            raise ValueError("cannot evaluate on an empty dataset")
        self.model.eval()
        total = 0.0
        with torch.no_grad():
            for batch_input, batch_output in data_gen(inputs, outputs, self.batch_size):
                if len(batch_output) == 0:
                    # Never feed an empty batch to the model.
                    continue
                total += self._compute_loss(batch_input, batch_output).detach().cpu().item()
        return total / n_samples

    @timing
    def train(self, inputs, outputs, val_inputs, val_outputs,draw_curve=False,early_stop=False,l2=False,silent=False):
        """ Train self.model and return the per-epoch loss history.

        inputs: array-like or tensor, converted to float. input_transform(inputs)
            is what the model receives; its first dimension must index samples.
        outputs: array-like or tensor of shape (ndata,), converted to int64.
        val_inputs, val_outputs: validation data, prepared the same way.
        draw_curve: bool. Plot training and validation loss against epoch number.
        early_stop: bool. Training still runs for all self.epoch epochs, but
            afterwards the weights from the epoch with the lowest validation
            loss are restored.
        l2: bool. Add 1e-5 * (sum of squared parameters) to every batch loss.
        silent: bool. Controls whether or not to print the train and val loss
            (every 10 epochs) during training.

        Returns {"losses": [...], "val_losses": [...]}, one float per epoch.
        Each value is the sum of the batch losses divided by the number of
        samples. NOTE(review): that is a per-sample average only if the loss
        sums over its batch — confirm against the loss_fn in use.
        """
        import copy  # used only to snapshot the best weights for early stopping

        inputs, outputs = self._prepare(inputs, outputs)
        # Validation tensors are prepared once here and evaluated through
        # _evaluate_prepared, so input_transform is not applied twice.
        val_inputs, val_outputs = self._prepare(val_inputs, val_outputs)
        n_train = len(outputs)
        if n_train == 0:
            raise ValueError("cannot train on an empty dataset")

        losses = []
        val_losses = []
        best_weights = None
        lowest_val_loss = float("inf")

        for n_epoch in tqdm(range(self.epoch), leave=False):
            self.model.train()
            # shuffle the data in each epoch
            idx = torch.randperm(n_train)
            inputs = inputs[idx]
            outputs = outputs[idx]

            epoch_total = 0.0
            for batch_input, batch_output in data_gen(inputs, outputs, self.batch_size):
                if len(batch_output) == 0:
                    # Never feed an empty batch to the model.
                    continue
                loss = self._compute_loss(batch_input, batch_output)
                if l2:
                    l2_lambda = 1e-5
                    l2_norm = sum(p.pow(2.0).sum() for p in self.model.parameters())
                    loss = loss + l2_lambda * l2_norm
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                epoch_total += loss.detach().cpu().item()
            # Normalise by the true sample count so a short final batch is weighted correctly.
            epoch_loss = epoch_total / n_train

            val_loss = self._evaluate_prepared(val_inputs, val_outputs)
            if n_epoch % 10 == 0 and not silent:
                print("Epoch %d/%d - Loss: %.3f " % (n_epoch + 1, self.epoch, epoch_loss))
                print("              Val_loss: %.3f" % (val_loss))
            losses.append(epoch_loss)
            val_losses.append(val_loss)
            if early_stop and val_loss < lowest_val_loss:
                lowest_val_loss = val_loss
                # state_dict() holds live references to the parameters, which
                # the optimizer keeps updating in place; copy it so the snapshot
                # really is this epoch's weights.
                best_weights = copy.deepcopy(self.model.state_dict())

        if draw_curve:
            epoch_numbers = range(1, len(losses) + 1)
            plt.figure()
            plt.plot(epoch_numbers, losses, label='Training loss')
            plt.plot(epoch_numbers, val_losses, label='Validation loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()

        # If no epoch ever improved (e.g. NaN losses), keep the final weights.
        if early_stop and best_weights is not None:
            self.model.load_state_dict(best_weights)

        return {"losses": losses,  "val_losses": val_losses}

    def evaluate(self, inputs, outputs, print_loss=True):
        """Return the loss of self.model on (inputs, outputs) without updating it.

        inputs/outputs may be arrays or tensors; they are converted to float /
        int64 tensors and input_transform is applied to inputs. The returned
        value is the summed batch loss divided by the number of samples.
        print_loss: bool. Also print the loss.
        Leaves the model in eval mode.
        """
        inputs, outputs = self._prepare(inputs, outputs)
        losses = self._evaluate_prepared(inputs, outputs)
        if print_loss:
            print("Loss: %.3f" % losses)
        return losses

# In [43]:
from sklearn.model_selection import train_test_split, KFold
from torchsummary import summary
def train_model(model_func,Xs,ys,test_Xs,test_ys,epochs,draw_curve=True,early_stop=False,batchsize=128, optimizer='adam',lr=1e-3,l2=False,input_shape=(-1,1,32,32)):
    """Build a model, train it on a split of (Xs, ys), print its test loss, and return it.

    model_func: zero-argument callable returning an nn.Module; its __name__
        is used in the printed parameter count.
    Xs, ys: training data. One third (train_test_split, random_state=0) is
        held out as the validation set.
    test_Xs, test_ys: data for the final printed "Test loss".
    epochs, batchsize, optimizer, lr: forwarded to Trainer.
    early_stop, l2: forwarded to Trainer.train.
    draw_curve: plot the per-epoch training and validation losses.
    input_shape: every input is reshaped to this shape before reaching the
        model; input_shape[1:] is the per-sample shape passed to summary().
    """
    train_Xs, val_Xs, train_ys, val_ys = train_test_split(Xs, ys, test_size=1/3, random_state=0)
    model = model_func()
    summary(model, input_shape[1:])

    n_params = sum(p.numel() for p in model.parameters())
    print(f"{model_func.__name__} parameters:", n_params)

    trainer = Trainer(model, optimizer, lr, epochs, batchsize, lambda x: x.reshape(input_shape))
    log = trainer.train(train_Xs, train_ys, val_Xs, val_ys, early_stop=early_stop, l2=l2)

    if draw_curve:
        plt.figure()
        plt.plot(log["losses"], label="losses")
        plt.plot(log["val_losses"], label="validation_losses")
        plt.legend()
        plt.title('loss')

    test_loss = trainer.evaluate(test_Xs, test_ys, print_loss=False)
    print("Test loss:", test_loss)
    return model


# In [54]:
def reconstruct(vae,data_gen):
    """Show one batch of inputs, then the VAE's reconstruction of that batch.

    vae: callable returning a 3-tuple whose first element is the reconstruction.
    data_gen: iterator of indexable items (e.g. (inputs, targets) pairs); it is
        advanced by exactly one item, and only element 0 of that item is used.
        (This parameter shadows the module-level data_gen function inside here.)
    """
    first_item = next(data_gen)
    originals = first_item[0]

    print('Original Data:')
    plot_digits(originals)

    # Inference only: no autograd graph is needed for the reconstruction.
    with torch.no_grad():
        model_input = torch.tensor(originals, dtype=torch.float)
        reconstruction, _mu, _log_var = vae(model_input)

    print('Reconstructed Data:')
    reconstructed_array = reconstruction.detach().numpy()
    plot_digits(reconstructed_array)
    
def plot_digits(data):
    """Draw up to 100 images of 32x32 pixels on a 10x10 grayscale grid.

    data: indexable collection whose items each reshape to (32, 32), e.g. an
        array of shape (n, 32, 32) or (n, 1024). Only the first 100 items are
        drawn; if there are fewer, the remaining grid cells are left blank
        instead of raising IndexError. Intensities are shown on a fixed
        [0, 1] scale.
    """
    fig, ax = plt.subplots(10, 10, figsize=(12, 12),
                           subplot_kw=dict(xticks=[], yticks=[]))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    n_images = len(data)
    for i, axi in enumerate(ax.flat):
        if i >= n_images:
            axi.axis('off')
            continue
        # vmin/vmax fix the colour limits to [0, 1] (same as set_clim(0, 1)).
        axi.imshow(data[i].reshape(32, 32), cmap='gray', vmin=0, vmax=1)

