This section implements the overall Meta-rPPG system.

In [None]:
# Import general libraries
import ipynb
import itertools
import numpy as np
import os
import pdb
import pickle
from scipy import signal
import time
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim

In [None]:
# Import classes created in other files
from ipynb.fs.full.Submodel_MetarPPG import Convolutional_Encoder, rPPG_Estimator, Synthetic_Gradient_Generator
from ipynb.fs.full.Loss_MetarPPG import ordLoss
from ipynb.fs.full.Data_MetarPPG import butter_bandpass_filter

In [None]:
class meta_rppg(nn.Module):
    """Meta-rPPG model: convolutional encoder, rPPG estimator and synthetic gradient generator.

    Each sub-model (``con``, ``rppg``, ``syn``) has its own SGD optimizer. A fourth
    optimizer, ``optimizer_psi``, adapts the encoder and/or estimator from the
    synthetic gradients produced by ``syn``. Which one it adapts is chosen by
    ``opt.adapt_position`` ('extractor', 'estimator' or 'both').

    NOTE(review): the encoder output is assumed to be (batch, 60, 120) and the
    LSTM hidden size 60, matching the hard-coded sizes below. TODO confirm
    against Submodel_MetarPPG.
    """

    _MOMENTUM = 0.9
    _WEIGHT_DECAY = 5e-4
    _ADAPT_POSITIONS = ('extractor', 'estimator', 'both')

    def __init__(self, opt, isTrain, continue_train=False, norm_layer=nn.BatchNorm2d):
        """Build the sub-models, loss criteria, optimizers and LR schedulers.

        Args:
            opt: option namespace. Reads ``gpu_ids``, ``lstm_num_layers``,
                ``batch_size``, ``lr`` and ``adapt_position``.
            isTrain: forwarded to every sub-model.
            continue_train: when True, ``setup`` reloads weights from disk.
            norm_layer: unused; kept for interface compatibility.

        Raises:
            ValueError: if ``opt.adapt_position`` is not 'extractor',
                'estimator' or 'both'.
        """
        # nn.Module must be initialised before any sub-module is assigned.
        super(meta_rppg, self).__init__()

        self.opt = opt
        self.isTrain = isTrain
        self.continue_train = continue_train
        self.threshold = 0.5
        self.gpu_ids = opt.gpu_ids
        if self.gpu_ids:
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')

        # Fail fast instead of leaving optimizer_psi undefined.
        if opt.adapt_position not in self._ADAPT_POSITIONS:
            raise ValueError('adapt_position must be one of %s, got %r'
                             % (self._ADAPT_POSITIONS, opt.adapt_position))

        # Directories read by save() / load_networks().
        # NOTE(review): these option names are assumed (the original never set
        # them); falls back to the working directory. Confirm against the options file.
        checkpoints_dir = getattr(opt, 'checkpoints_dir', '.')
        self.save_dir = getattr(opt, 'save_dir', checkpoints_dir)
        self.load_dir = getattr(opt, 'load_dir', checkpoints_dir)

        # Running feature prototype and LSTM (h, c) state. They are kept on the
        # model device because they are compared with / fed to device tensors.
        self.prototype = torch.zeros(120, device=self.device)
        self.h = torch.zeros(2 * opt.lstm_num_layers, opt.batch_size, 60, device=self.device)
        self.c = torch.zeros(2 * opt.lstm_num_layers, opt.batch_size, 60, device=self.device)

        # The three sub-models.
        self.con = Convolutional_Encoder(no_input_channel=3, isTrain=self.isTrain, device=self.device)
        self.rppg = rPPG_Estimator(no_input_channel=120, num_layers=opt.lstm_num_layers,
                                   isTrain=self.isTrain, device=self.device, h=self.h, c=self.c)
        self.syn = Synthetic_Gradient_Generator(number_input_channels=120, isTrain=self.isTrain,
                                                device=self.device)
        for sub_model in (self.con, self.rppg, self.syn):
            sub_model.to(self.device)
        # Plain list (not nn.ModuleList) so the sub-models are not registered twice.
        self.model = [self.con, self.rppg, self.syn]

        # Loss values exposed through get_current_losses().
        self.few_shotloss = 0.0
        self.ordloss = 0.0
        self.gradientloss = 0.0
        self.gradloss = 0.0  # read by get_current_losses() before the first psi_phi_update()

        self.criterion1 = torch.nn.MSELoss()  # prototype loss
        self.criterion2 = ordLoss()           # loss between estimate and rPPG target
        self.criterion3 = torch.nn.MSELoss()  # synthetic-gradient regression loss

        lr = opt.lr
        self.optimizer_con = self._sgd(self.con.parameters(), lr)
        self.optimizer_rppg = self._sgd(self.rppg.parameters(), lr)
        self.optimizer_syn = self._sgd(self.syn.parameters(), lr)

        # psi: parameters adapted from synthetic gradients, at 1/100 of the base LR.
        psi_lr = lr * 1e-2
        if opt.adapt_position == 'extractor':
            psi_params = self.con.parameters()
        elif opt.adapt_position == 'estimator':
            psi_params = self.rppg.parameters()
        else:  # 'both'
            psi_params = itertools.chain(self.con.parameters(), self.rppg.parameters())
        self.optimizer_psi = self._sgd(psi_params, psi_lr)

        self.scheduler_con = optim.lr_scheduler.CosineAnnealingLR(self.optimizer_con, T_max=5, eta_min=0.1 * lr)
        self.scheduler_rppg = optim.lr_scheduler.CosineAnnealingLR(self.optimizer_rppg, T_max=5, eta_min=0.1 * lr)
        self.scheduler_syn = optim.lr_scheduler.CosineAnnealingLR(self.optimizer_syn, T_max=5, eta_min=0.1 * lr)
        self.scheduler_psi = optim.lr_scheduler.CosineAnnealingLR(self.optimizer_psi, T_max=5, eta_min=0.1 * psi_lr)

    @classmethod
    def _sgd(cls, params, lr):
        """Build an SGD optimizer with this model's momentum and weight decay.

        Keywords are used deliberately: SGD's third positional argument is
        ``dampening``, not ``weight_decay``.
        """
        return torch.optim.SGD(params, lr=lr, momentum=cls._MOMENTUM, weight_decay=cls._WEIGHT_DECAY)

    @staticmethod
    def _apply_synthetic_gradient(output, module, synthetic_grad):
        """Backpropagate ``synthetic_grad`` from ``output`` into ``module``'s parameter ``.grad``."""
        params = list(module.parameters())
        grads = torch.autograd.grad(outputs=output, inputs=params, grad_outputs=synthetic_grad,
                                    create_graph=False, retain_graph=False)
        # Calling backward on the leaf parameters accumulates `grads` into their .grad fields.
        torch.autograd.backward(params, grad_tensors=grads, retain_graph=False, create_graph=False)

    def _adapt_flags(self):
        """Return ``(adapt_inter, adapt_estimate)`` for the configured adapt_position."""
        position = self.opt.adapt_position
        return position in ('extractor', 'both'), position in ('estimator', 'both')

    def _prototype_loss(self, inter):
        """Return criterion1 between the prototype (broadcast to ``inter``'s shape) and ``inter``."""
        return self.criterion1(self.prototype.expand_as(inter), inter)

    def print_networks(self, print_net):
        """Print the total parameter count of the three sub-models.

        Args:
            print_net: when truthy, also print the sub-models themselves.
        """
        print('----------- Networks initialized -------------')
        num = sum(p.numel() for net in (self.con, self.rppg, self.syn) for p in net.parameters())
        if print_net:
            print(self.model)
        print('Total number of parameters : %.3f M' % (num / 1e6))
        print('----------------------------------------------')

    def set_input(self, input):
        """Store the batch dict's 'input', 'rPPG' and 'center' entries on the model."""
        self.input = input['input']
        self.rPPG = input['rPPG']
        self.center = input['center']

    def forward(self, x):
        """Run encoder and estimator on ``x`` and compute synthetic gradient(s).

        Results are stored on ``self``: ``inter``, ``condition``, ``estimate``, and
        ``gradient`` (or ``gradient1``/``gradient2`` when adapt_position is 'both').
        """
        self.inter = self.con(x)
        self.condition, self.estimate = self.rppg(self.inter)
        position = self.opt.adapt_position
        if position == 'extractor':
            self.gradient = self.syn(self.inter.detach())
        elif position == 'estimator':
            self.gradient = self.syn(self.estimate.detach())
        elif position == 'both':
            self.gradient1 = self.syn(self.inter.detach())
            self.gradient2 = self.syn(self.estimate.detach())

    def theta_update(self, epoch):
        """Update the encoder, then adapt psi with synthetic gradients for ``opt.fewshots`` steps.

        Args:
            epoch: unused; kept for interface compatibility.
        """
        inter = self.con(self.input.to(self.device))
        condition, estimate = self.rppg(inter)

        few_shotloss = self._prototype_loss(inter)
        ordloss = self.criterion2(estimate, self.rPPG.to(self.device))

        self.optimizer_con.zero_grad()
        loss = few_shotloss + ordloss
        loss.backward()
        self.optimizer_con.step()

        adapt_inter, adapt_estimate = self._adapt_flags()
        inter_grad = estimate_grad = None
        for _ in range(self.opt.fewshots):
            inter = self.con(self.input.to(self.device))
            condition, estimate = self.rppg(inter)
            if adapt_inter:
                inter_grad = self.syn(inter.detach())
            if adapt_estimate:
                estimate_grad = self.syn(estimate.detach())
            self.optimizer_psi.zero_grad()
            if adapt_inter:
                self._apply_synthetic_gradient(inter, self.con, inter_grad)
            if adapt_estimate:
                self._apply_synthetic_gradient(estimate, self.rppg, estimate_grad)
            self.optimizer_psi.step()

        # Keep the last synthetic gradient (estimate-side when it exists). It is
        # only available if at least one adaptation step ran.
        last_grad = estimate_grad if adapt_estimate else inter_grad
        if last_grad is not None:
            self.gradient = last_grad.detach().clone()

        self.few_shotloss = few_shotloss.detach().clone()
        self.ordloss = ordloss.detach().clone()
        self.inter = inter.detach().clone()

    def psi_phi_update(self, epoch):
        """Update encoder/estimator on the supervised loss, then fit ``syn`` to the true gradients.

        Args:
            epoch: unused; kept for interface compatibility.
        """
        adapt_inter, adapt_estimate = self._adapt_flags()
        inter = self.con(self.input.to(self.device))
        condition, estimate = self.rppg(inter)
        if adapt_inter:
            inter_grad = self.syn(inter.detach())
        if adapt_estimate:
            estimate_grad = self.syn(estimate.detach())

        # Keep .grad on these non-leaf tensors; they are the regression targets for syn.
        if adapt_inter:
            inter.retain_grad()
        if adapt_estimate:
            estimate.retain_grad()

        few_shotloss = self._prototype_loss(inter)
        ordloss = self.criterion2(estimate, self.rPPG.to(self.device))
        loss = few_shotloss + ordloss

        self.optimizer_con.zero_grad()
        self.optimizer_rppg.zero_grad()
        loss.backward()
        self.optimizer_con.step()
        self.optimizer_rppg.step()

        grad_terms = []
        if adapt_inter:
            grad_terms.append(self.criterion3(inter_grad, inter.grad))
        if adapt_estimate:
            grad_terms.append(self.criterion3(estimate_grad, estimate.grad))
        gradloss = sum(grad_terms)
        self.optimizer_syn.zero_grad()
        gradloss.backward()
        self.optimizer_syn.step()
        self.gradloss = gradloss.detach().clone()

        self.condition = condition.detach().clone()
        self.estimate = estimate.detach().clone()
        self.ordloss = ordloss.detach().clone()

    def update_prototype(self):
        """Refresh the feature prototype and stored LSTM state from the current input.

        The first call (prototype still all zeros) copies the new values. Later
        calls blend them as 0.8 * old + 0.2 * new.
        """
        self.rppg.feed_hc([self.h, self.c])
        self.forward(self.input.to(self.device))
        proto = self.inter.detach().mean(dim=(0, 1))
        h_new = self.rppg.h.detach().clone()
        c_new = self.rppg.c.detach().clone()

        if not self.prototype.any():
            self.prototype = proto
            self.h, self.c = h_new, c_new
        else:
            self.prototype = 0.8 * self.prototype + 0.2 * proto
            self.h, self.c = 0.8 * self.h + 0.2 * h_new, 0.8 * self.c + 0.2 * c_new

    def init_weights(self, net1, net2, init_type='normal', init_gain=0.02):
        """Initialise Conv/Linear/BatchNorm weights of ``net1`` and ``net2`` in place.

        Conv and Linear weights use ``init_type`` ('normal', 'xavier', 'kaiming'
        or 'orthogonal') and their biases are zeroed. BatchNorm weights are drawn
        from N(1, init_gain) and their biases are zeroed. Other layers keep their
        default initialisation.

        Raises:
            NotImplementedError: for an unknown ``init_type``.
        """
        if init_type not in ('normal', 'xavier', 'kaiming', 'orthogonal'):
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

        def init_func(m):
            classname = m.__class__.__name__
            weight = getattr(m, 'weight', None)
            if weight is None:
                return
            if 'Conv' in classname or 'Linear' in classname:
                if init_type == 'normal':
                    init.normal_(weight.data, 0.0, init_gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(weight.data, gain=init_gain)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(weight.data, a=0, mode='fan_in')
                else:
                    init.orthogonal_(weight.data, gain=init_gain)
                if getattr(m, 'bias', None) is not None:
                    init.constant_(m.bias.data, 0.0)
            elif 'BatchNorm' in classname:
                init.normal_(weight.data, 1.0, init_gain)
                if getattr(m, 'bias', None) is not None:
                    init.constant_(m.bias.data, 0.0)

        net1.apply(init_func)
        net2.apply(init_func)

    def setup(self, opt):
        """Initialise weights, optionally reload a checkpoint, and print a summary."""
        self.init_weights(self.con, self.rppg)
        if self.continue_train:
            # NOTE(review): sets `thres`, while __init__ defines `threshold`; confirm which one is read.
            self.thres = 0.01
        if self.continue_train or not self.isTrain:
            self.load_networks(opt.load_file)
        self.print_networks(opt.print_net)

    def save(self, suffix):
        """Save sub-model weights, prototype and LSTM state to ``save_dir``."""
        save_filename = '%s_%s.pth' % (suffix, self.opt.name)
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save({'Con': self.con.state_dict(), 'rPPG': self.rppg.state_dict(), 'SGG': self.syn.state_dict(),
                    'proto': self.prototype.cpu(), 'h': self.h.data.cpu(), 'c': self.c.data.cpu()},
                   save_path)

    def get_current_results(self, isTest):
        """Return the last sample's ``condition`` output and target rPPG, both on CPU."""
        return self.condition[-1].cpu().clone(), self.rPPG[-1].cpu().clone()

    def get_current_losses(self, isTest):
        """Return the test loss, or ``[few_shotloss, gradloss, ordloss]`` during training."""
        if isTest:
            return self.new_ordless
        return [self.few_shotloss, self.gradloss, self.ordloss]

    def train(self, mode=True):
        """Set training mode (``mode=True``) or eval mode on this model and its sub-models.

        Keeps nn.Module's signature so that ``eval()``, which calls
        ``self.train(False)``, works. Returns ``self``.
        """
        return super(meta_rppg, self).train(mode)

    def few_shotloss_test(self, epoch):
        """Adapt a copy of the encoder on the first ``opt.fewshots`` samples, then evaluate the rest.

        The live encoder and the estimator's LSTM state are left unchanged.
        Results are stored in ``condition``/``estimate`` (aliases
        ``decision``/``predict``) and the loss in ``new_ordless``.
        """
        import copy

        shots = self.opt.fewshots
        conv = copy.deepcopy(self.con)
        optimizer = self._sgd(conv.parameters(), self.opt.lr * 1e-2)

        # Step 1: adapt with synthetic gradients.
        for i in range(shots):
            optimizer.zero_grad()
            inter = conv(self.input[i].unsqueeze(0).to(self.device))
            inter_grad = self.syn(inter.detach())
            self._apply_synthetic_gradient(inter, conv, inter_grad)
            optimizer.step()
        # Step 2: pull features toward the prototype.
        for i in range(shots):
            optimizer.zero_grad()
            inter = conv(self.input[i].unsqueeze(0).to(self.device))
            loss = self.criterion1(inter, self.prototype.expand_as(inter))
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            saved_h, saved_c = self.rppg.h, self.rppg.c
            self.rppg.feed_hc([self.h, self.c])
            inter = conv(self.input[shots:].to(self.device))
            self.decision, self.predict = self.rppg(inter)
            self.rppg.feed_hc([saved_h, saved_c])
        # Mirror into the names read by get_current_results().
        self.condition, self.estimate = self.decision, self.predict

        # NOTE(review): predict[0] is the output for input[fewshots] but is compared
        # with rPPG[0]; confirm the intended alignment.
        self.new_ordless = self.criterion2(self.predict[0].unsqueeze(0),
                                           self.rPPG[0].unsqueeze(0).to(self.device))

    def load_networks(self, suffix):
        """Load sub-model weights, prototype and LSTM state from ``load_dir`` onto the model device."""
        load_filename = '%s_%s.pth' % (suffix, self.opt.name)
        load_path = os.path.join(self.load_dir, load_filename)
        model_dict = torch.load(load_path, map_location=self.device)
        self.con.load_state_dict(model_dict['Con'])
        self.rppg.load_state_dict(model_dict['rPPG'])
        self.syn.load_state_dict(model_dict['SGG'])
        self.prototype = model_dict['proto'].to(self.device)
        self.h = model_dict['h'].to(self.device)
        self.c = model_dict['c'].to(self.device)

    def update_lr(self, epoch):
        """Step every LR scheduler and return the estimator's current learning rate."""
        for scheduler in (self.scheduler_con, self.scheduler_rppg, self.scheduler_syn, self.scheduler_psi):
            scheduler.step()
        return self.optimizer_rppg.param_groups[0]['lr']