In [19]:
import numpy as np
import torch
import torchvision 
import networks


class Trainer():
    """Train and evaluate a network that predicts inverse depth maps.

    Each batch is an ``(image, depth_map)`` pair. The network output is compared,
    through ``loss_function(prediction, inverse_depth, mask)``, with ``1 / depth_map``;
    pixels whose depth is not > 0 are masked out and their inverse depth is set to 0.
    """

    def __init__(self, network, network_name, device, train_loader, test_loader, optimizer, loss_fuction, num_epochs, save_model_path, save_optim_path):
        """Store everything needed for training.

        Args:
            network: torch module mapping an image batch to a predicted inverse-depth map.
            network_name: identifier of the architecture (e.g. 'midas'), kept for reference.
            device: device every batch is moved to before the forward pass.
            train_loader: iterable of ``(image, depth_map)`` training batches.
            test_loader: iterable of ``(image, depth_map)`` evaluation batches.
            optimizer: torch optimizer over ``network``'s parameters.
            loss_fuction: callable ``(prediction, inverse_depth, mask) -> scalar tensor``.
                The misspelled name is kept so keyword callers keep working.
            num_epochs: number of epochs run by ``full_training_process``.
            save_model_path: where the best network ``state_dict`` is saved.
            save_optim_path: where the matching optimizer ``state_dict`` is saved.
        """
        self.network = network
        self.network_name = network_name
        self.device = device
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        # Previously this read an unrelated global ``loss_function`` instead of the argument.
        self.loss_function = loss_fuction
        self.save_model_path = save_model_path
        self.save_optim_path = save_optim_path

    def _to_device(self, image, depth_map):
        """Move one batch to ``self.device`` as float32 tensors; returns ``(image, depth)``."""
        # ``Tensor.to`` is not in-place, so the result must be kept.
        image = image.to(device=self.device, dtype=torch.float32)
        d = depth_map.to(device=self.device, dtype=torch.float32)
        return image, d

    def _forward(self, image):
        """Run the network on an image batch after dropping its singleton dim 1."""
        # squeeze(1) rather than squeeze(): a bare squeeze() would also drop the
        # batch dimension when the batch size is 1.
        return self.network(image.squeeze(1))

    @staticmethod
    def _inverse_depth(d):
        """Return ``1 / d`` with infinities (pixels where d == 0) replaced by 0."""
        inv_d = 1.0 / d
        inv_d[torch.isinf(inv_d)] = 0
        return inv_d

    def _compute_loss(self, inv_d_hat, d):
        """Masked loss between predicted inverse depth and the inverse of ``d``."""
        # Only pixels with available ground truth (depth > 0) contribute.
        mask = d > 0.0
        return self.loss_function(inv_d_hat.squeeze(), self._inverse_depth(d), mask).float()

    def _decay_learning_rate(self, factor):
        """Divide the learning rate of every optimizer param group by ``factor``."""
        for group in self.optimizer.param_groups:
            group["lr"] /= factor

    def validation_epoch(self):
        """Evaluate the network on ``self.test_loader`` without tracking gradients.

        Returns:
            ``(val_loss, abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)``: the sum of
            the batch losses followed by the mean of each per-image metric
            (NaN when the loader yields nothing).
        """
        self.network.eval()
        val_loss = 0.0
        abs_rel_all, sq_rel_all, rmse_all, rmse_log_all = [], [], [], []
        a1_all, a2_all, a3_all = [], [], []

        with torch.no_grad():
            for image, depth_map in self.test_loader:
                image, d = self._to_device(image, depth_map)
                inv_d_hat = self._forward(image)
                val_loss += self._compute_loss(inv_d_hat, d).item()

                # Metrics are computed on depth, not inverse depth; a zero
                # inverse-depth prediction would give an infinite depth.
                d_hat = 1.0 / inv_d_hat
                d_hat[torch.isinf(d_hat)] = 0
                # compute_errors_2 is defined elsewhere in the notebook; its result
                # is unpacked as seven sequences of per-image values.
                abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 = compute_errors_2(d.cpu(), d_hat.cpu())
                abs_rel_all.extend(abs_rel)
                sq_rel_all.extend(sq_rel)
                rmse_all.extend(rmse)
                rmse_log_all.extend(rmse_log)
                a1_all.extend(a1)
                a2_all.extend(a2)
                a3_all.extend(a3)

        return (val_loss,
                np.array(abs_rel_all).mean(), np.array(sq_rel_all).mean(),
                np.array(rmse_all).mean(), np.array(rmse_log_all).mean(),
                np.array(a1_all).mean(), np.array(a2_all).mean(), np.array(a3_all).mean())

    def train_epoch(self):
        """Train the network for one pass over ``self.train_loader``.

        Returns:
            The sum of the per-batch losses as a Python float.
        """
        self.network.train()
        total_loss_epoch = 0.0

        for image, depth_map in self.train_loader:
            image, d = self._to_device(image, depth_map)
            inv_d_hat = self._forward(image)
            loss = self._compute_loss(inv_d_hat, d)
            total_loss_epoch += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return total_loss_epoch

    def full_training_process(self, print_metrics=True, lr_decay=False):
        """Run ``self.num_epochs`` epochs of training, evaluating after each one.

        Args:
            print_metrics: also print RMSE / RMSLE / abs rel / sq rel every epoch.
            lr_decay: after epoch 50, divide the optimizer's learning rate by 5
                every 10 epochs (applied in place to ``self.optimizer``).

        Returns:
            ``(train_losses, val_losses, rmse, rmse_log, abs_rel, sq_rel, a1, a2, a3)``.
            Every list except ``train_losses`` starts with the evaluation done
            before training, so it has ``num_epochs + 1`` entries.

        Side effects:
            Saves the network and optimizer state dicts to ``self.save_model_path``
            and ``self.save_optim_path`` whenever the validation RMSE reaches a new minimum.
        """
        metric_names = ("val_loss", "abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3")
        # First evaluation before any training step.
        history = {name: [value] for name, value in zip(metric_names, self.validation_epoch())}
        train_losses = []
        lowest_rmse_value = float("inf")

        for epoch in range(self.num_epochs):
            if lr_decay and epoch > 50 and epoch % 10 == 0:
                self._decay_learning_rate(5)

            train_loss_epoch = self.train_epoch()
            for name, value in zip(metric_names, self.validation_epoch()):
                history[name].append(value)
            train_losses.append(train_loss_epoch)

            val_loss_epoch = history["val_loss"][-1]
            rmse_epoch = history["rmse"][-1]
            print(f'\n EPOCH {epoch + 1}/{self.num_epochs} \t train loss {train_loss_epoch:.4f} \t val loss {val_loss_epoch:.4f} ')
            if print_metrics:
                print(f'\n RMSE {rmse_epoch:.4f} \t RMSLE {history["rmse_log"][-1]:.4f} \t abs rel {history["abs_rel"][-1]:.4f} \t sq rel {history["sq_rel"][-1]:.4f}')

            # Checkpoint only when the RMSE improves on every previous epoch.
            if rmse_epoch < lowest_rmse_value:
                lowest_rmse_value = rmse_epoch
                torch.save(self.network.state_dict(), self.save_model_path)
                torch.save(self.optimizer.state_dict(), self.save_optim_path)

        return (train_losses, history["val_loss"], history["rmse"], history["rmse_log"],
                history["abs_rel"], history["sq_rel"], history["a1"], history["a2"], history["a3"])

In [20]:
# Smoke test: build a Trainer from placeholder (None) arguments to check that
# the constructor runs; nothing is trained here.
(network, network_name, device, train_loader, test_loader, optimizer,
 loss_function, num_epochs, save_model_path, save_optim_path) = (None,) * 10
test_trainer = Trainer(
    network, network_name, device, train_loader, test_loader,
    optimizer, loss_function, num_epochs, save_model_path, save_optim_path,
)

In [29]:
def validation_epoch(self):
    """Evaluate ``self.network`` on the test loader without tracking gradients.

    Written as a free function taking a Trainer-like object. That object must
    provide ``network``, ``device`` and ``loss_function``, plus either
    ``test_loader`` or the older ``test_dataloader`` attribute.

    Returns:
        ``(val_loss, abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)``: the sum of the
        batch losses followed by the mean of each per-image metric
        (NaN when the loader yields nothing).
    """
    self.network.eval()  # evaluation mode, equivalent to network.train(False)
    # Trainer stores the loader as ``test_loader``; fall back to the legacy name.
    loader = getattr(self, "test_loader", None)
    if loader is None:
        loader = self.test_dataloader

    val_loss = 0.0
    abs_rel_epoch, sq_rel_epoch, rmse_epoch, rmse_log_epoch = [], [], [], []
    a1_epoch, a2_epoch, a3_epoch = [], [], []

    with torch.no_grad():  # no need to track the gradients
        for image, depth_map in loader:
            # Tensor.to is not in-place: keep the moved/cast tensors.
            image = image.to(device=self.device, dtype=torch.float32)
            d = depth_map.to(device=self.device, dtype=torch.float32)

            # squeeze(1) drops the singleton dim without touching the batch dim.
            inv_d_hat = self.network(image.squeeze(1))

            # Inverse ground-truth depth; pixels with d == 0 would be infinite.
            inv_d = 1.0 / d
            inv_d[torch.isinf(inv_d)] = 0

            # Only pixels with available ground truth contribute to the loss.
            mask = d > 0.0
            loss = self.loss_function(inv_d_hat.squeeze(), inv_d, mask).float()
            val_loss += loss.item()

            # Metrics have to be computed on depth maps, not inverse depth.
            d_hat = 1.0 / inv_d_hat
            d_hat[torch.isinf(d_hat)] = 0
            # compute_errors_2 is defined elsewhere in the notebook; its result is
            # unpacked as seven sequences of per-image values.
            abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 = compute_errors_2(d.cpu(), d_hat.cpu())

            abs_rel_epoch.extend(abs_rel)
            sq_rel_epoch.extend(sq_rel)
            rmse_epoch.extend(rmse)
            rmse_log_epoch.extend(rmse_log)
            a1_epoch.extend(a1)
            a2_epoch.extend(a2)
            a3_epoch.extend(a3)

    return (val_loss,
            np.array(abs_rel_epoch).mean(), np.array(sq_rel_epoch).mean(),
            np.array(rmse_epoch).mean(), np.array(rmse_log_epoch).mean(),
            np.array(a1_epoch).mean(), np.array(a2_epoch).mean(), np.array(a3_epoch).mean())
