Knowledge distillation for UNet
Ref: https://github.com/VaticanCameos99/knowledge-distillation-for-unet/blob/master/train.py

In [16]:
import glob
import torch
# import dataset
import numpy as np
from models.unet_optimize import UNet, SimpleUNet
from models.losses import loss_fn_kd
from models.metrics import dice_loss
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
import os

In [17]:
from ml4floods.models.config_setup import get_default_config
# Load the experiment configuration from its JSON file.
config_fp = "train_models/training_flooding/config.json"
config = get_default_config(config_fp)
# NOTE(review): this assigns an attribute named `data_params` *inside*
# `config.data_params`. The printed config has a `bucket_id` key holding the
# same value, so `bucket_id` may have been the intended target -- confirm.
config.data_params.data_params='ml4cc_data_lake'
# Checkpoint path: <model_folder>/<experiment_name>/model.pt, with any
# backslashes produced by os.path.join replaced by forward slashes.
path_to_models = os.path.join(config.model_params.model_folder,config.experiment_name, "model.pt").replace("\\","/")
teacher_weights = path_to_models
# Bare expression: displays the path as the notebook cell output.
teacher_weights

Loaded Config for experiment:  training_flooding
{   'data_params': {   'batch_size': 32,
                       'bucket_id': 'ml4cc_data_lake',
                       'channel_configuration': 'all',
                       'download': {'test': True, 'train': True, 'val': True},
                       'filter_windows': {   'apply': False,
                                             'threshold_clouds': 0.5,
                                             'version': 'v1'},
                       'input_folder': 'S2',
                       'loader_type': 'local',
                       'num_workers': 4,
                       'path_to_splits': '/mnt/d/Flooding/worldfloods_v1_sample',
                       'target_folder': 'gt',
                       'test_transformation': {'normalize': True},
                       'train_test_split_file': '2_PROD/2_Mart/worldfloods_v1_0/train_test_split.json',
                       'train_transformation': {'normalize': True},
                       'win

'train_models/training_flooding/model.pt'

In [21]:
# Read `num_channels` and `num_classes` from the model hyperparameters of the
# loaded config.
num_channels = config.model_params.hyperparameters.num_channels
num_classes = config.model_params.hyperparameters.num_classes

In [22]:
def fetch_teacher_outputs(teacher, train_loader):
    """Run the teacher once over ``train_loader`` and cache its outputs.

    Args:
        teacher: Model invoked as ``teacher(img)``. It is switched to eval
            mode and moved to the GPU only when CUDA is available.
        train_loader: Iterable yielding ``(img, gt)`` batches; only ``img``
            is used here.

    Returns:
        list: One teacher output per batch, in iteration order. The outputs
        are stored on the CPU so the cache does not pin GPU memory for the
        whole dataset.

    Note:
        ``train_student`` looks these outputs up by batch index, so the
        loader must yield the batches in the same order on every pass.
    """
    print('-------Fetch teacher outputs-------')
    use_cuda = torch.cuda.is_available()
    teacher.eval()
    if use_cuda:
        # Guarded so the function also works on CPU-only machines.
        teacher.cuda()

    # list of tensors, one entry per batch
    teacher_outputs = []
    with torch.no_grad():
        for i, (img, _) in enumerate(train_loader):
            print(i, 'i')
            if use_cuda:
                img = img.cuda()

            # Move the result to host memory; the consumer re-uploads each
            # batch to the GPU when it needs it.
            teacher_outputs.append(teacher(img).cpu())
    return teacher_outputs

def train_student(student, teacher_outputs, optimizer, train_loader):
    """Train the student for one epoch against cached teacher outputs.

    Args:
        student: Model being trained; put in train mode and moved to the GPU
            only when CUDA is available.
        teacher_outputs: Sequence indexed by batch number holding the
            teacher output for the batch at that position of
            ``train_loader``.
        optimizer: Optimizer that updates the student's parameters.
        train_loader: Iterable yielding ``(img, gt)`` batches, in the same
            order that was used to build ``teacher_outputs``.

    Returns:
        None. The mean metric and mean loss over the sampled batches are
        printed.

    Note:
        Reads the module-level name ``summary_steps`` (metrics are sampled
        every ``summary_steps`` batches); it must be defined before calling.
    """
    print('-------Train student-------')
    # called once for each epoch
    use_cuda = torch.cuda.is_available()
    student.train()
    if use_cuda:
        # Guarded so the function also works on CPU-only machines.
        student.cuda()

    summ = []
    for i, (img, gt) in enumerate(train_loader):
        teacher_output = teacher_outputs[i]
        if use_cuda:
            img, gt = img.cuda(), gt.cuda()
            teacher_output = teacher_output.cuda()

        output = student(img)

        #TODO: loss is wrong
        loss = loss_fn_kd(output, teacher_output, gt)

        # clear previous gradients, compute gradients of all variables wrt loss
        optimizer.zero_grad()
        loss.backward()

        # performs updates using calculated gradients
        optimizer.step()

        if i % summary_steps == 0:
            # The metric is only reported, never back-propagated, so skip
            # building an autograd graph for it.
            with torch.no_grad():
                metric = dice_loss(output, gt)
            summ.append({'metric': metric.item(), 'loss': loss.item()})

    mean_dice_coeff = np.mean([x['metric'] for x in summ])
    mean_loss = np.mean([x['loss'] for x in summ])
    print('- Train metrics:\n' + '\tMetric:{}\n\tLoss:{}'.format(mean_dice_coeff, mean_loss))

def evaluate_kd(student, val_loader):
    """Evaluate the student on ``val_loader`` and return the mean dice loss.

    Args:
        student: Model to evaluate; put in eval mode and moved to the GPU
            only when CUDA is available.
        val_loader: Iterable yielding ``(img, gt)`` batches.

    Returns:
        The mean of the per-batch ``dice_loss`` values, as computed by
        ``np.mean``.
    """
    print('-------Evaluate student-------')
    use_cuda = torch.cuda.is_available()
    student.eval()
    if use_cuda:
        # Guarded so the function also works on CPU-only machines.
        student.cuda()

    loss_summ = []
    with torch.no_grad():
        for img, gt in val_loader:
            if use_cuda:
                img, gt = img.cuda(), gt.cuda()

            # Predictions are clipped to [0, 1] before the dice loss is taken.
            output = student(img).clamp(min=0, max=1)
            loss_summ.append(dice_loss(output, gt).item())

    mean_loss = np.mean(loss_summ)
    print('- Eval metrics:\n\tAverage Dice loss:{}'.format(mean_loss))
    return mean_loss

In [25]:
# teacher = UNet(num_channels, num_classes)
# student = SimpleUNet(num_channels, num_classes)

In [26]:
# NOTE(review): `student` is not defined in the visible cells -- its
# instantiation is commented out in the cell just above -- so it must already
# exist in the notebook session for this cell to run.
# Adam over the student's parameters with learning rate 1e-3.
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
# StepLR with step_size=100 and gamma=0.2: the learning rate is multiplied by
# 0.2 every 100 scheduler steps.
scheduler = StepLR(optimizer, step_size = 100, gamma = 0.2)

In [34]:
from models.flooding_model import WorldFloodsModel
# Build the teacher from the model section of the loaded config.
teacher_model = WorldFloodsModel(config.model_params)
# Checkpoint path: <model_folder>/<experiment_name>/model.pt, with any
# backslashes produced by os.path.join replaced by forward slashes.
path_to_models = os.path.join(config.model_params.model_folder,config.experiment_name, "model.pt").replace("\\","/")
print(path_to_models)
from pytorch_lightning.utilities.cloud_io import load
# Read the checkpoint and copy its weights into the teacher; the return value
# of load_state_dict is what the notebook displays as the cell output.
teacher_model.load_state_dict(load(path_to_models))

train_models/training_flooding/model.pt


<All keys matched successfully>

In [None]:
# Bare expression: displays the teacher model as the notebook cell output.
teacher_model