In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os
import matplotlib.pyplot as plt
import torch
from visualizer import visualize_random_image
from train import setup_train_args, build, test

In [None]:
# Dataset choices available to this notebook; "both" presumably means
# Kvasir and CVC combined (confirm against the data-loading code).
datasets = ["Kvasir", "CVC", "both"]
# Active selection: the last entry, i.e. "both".
dataset = datasets[-1]

In [None]:
from train import student_train_epoch

# Run configuration for the student model. `temperature` and `alpha` are later
# returned by `build(args)` and passed to `student_train_epoch` together with
# the KLT loss (presumably the softmax temperature and the loss-mixing weight
# of the teacher-supported objective -- confirm in train.py).
args_train_student = setup_train_args(
    my_model="unet",
    temperature=4,
    alpha=0.9,
    dataset=dataset,
    data_root="./data_root/",
    batch_size=8,
    epochs=20,
    learning_rate=0.01,
)

In [None]:
def _student_checkpoint_path(args, student_model, temperature, alpha):
    """Build the path of the best-student checkpoint file.

    The model part of the file name is a short identifier (``args.my_model``
    when present, otherwise the class name of the unwrapped model) instead of
    the full module repr, which is multi-line and not a usable file name.
    The rest of the name keeps the original layout.
    """
    # Unwrap with the same rule that is used when saving the state dict.
    bare_model = student_model if args.mgpu is False else student_model.module
    model_name = getattr(args, "my_model", None) or type(bare_model).__name__
    file_name = f"{model_name}_temp_{temperature}_alpha_{alpha}" + args.dataset + ".pt"
    return "Trained_student_models/teacher_supported/" + file_name


def train_student(args):
    """Train the student model with teacher support and keep the best checkpoint.

    Everything needed for training (device, dataloaders, losses, metric,
    student/teacher models, optimizer, alpha, temperature) comes from
    ``build(args)``. After each epoch the student is evaluated on the
    validation dataloader; whenever the mean metric improves on the best value
    seen so far, a checkpoint is written. A random prediction from the test
    dataloader is visualized after every epoch.

    Args:
        args: Namespace produced by ``setup_train_args``. Fields read here:
            ``epochs``, ``lrs`` (the string ``"true"`` enables the plateau
            scheduler), ``lrs_min``, ``mgpu`` and ``dataset``.

    Returns:
        tuple[list, list]: ``(loss_storage, perf_storage)`` with one entry per
        completed epoch -- the value returned by ``student_train_epoch`` and
        the mean metric returned by ``test`` (dice score, per the original
        comment). If training is interrupted with Ctrl-C / "Interrupt kernel",
        the histories collected so far are returned.
    """
    (device,
     train_dataloader,
     test_dataloader,
     val_dataloader,
     Dice_loss,
     BCE_loss,
     KLT_loss,
     perf,
     student_model,
     teach_model,
     optimizer,
     alpha,
     temperature) = build(args)

    loss_storage = []
    # performance metric is dice score
    perf_storage = []

    # Kept from the original code, which always created this directory.
    os.makedirs("./Trained models", exist_ok=True)
    # Make sure the directory torch.save actually writes into exists.
    checkpoint_path = _student_checkpoint_path(args, student_model, temperature, alpha)
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    scheduler = None
    if args.lrs == "true":
        # mode="max": the monitored quantity is a score that should increase.
        scheduler_kwargs = {"mode": "max", "factor": 0.5, "patience": 5, "verbose": True}
        if args.lrs_min > 0:
            scheduler_kwargs["min_lr"] = args.lrs_min
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    prev_best_test = None
    for epoch in range(1, args.epochs + 1):
        try:
            loss = student_train_epoch(student_model, teach_model, device, train_dataloader,
                                       optimizer, epoch, Dice_loss, BCE_loss, KLT_loss, temperature, alpha)

            test_measure_mean, test_measure_std = test(
                student_model, device, val_dataloader, epoch, perf
            )
            # add means and loss for vis
            loss_storage.append(loss)
            perf_storage.append(test_measure_mean)
        except KeyboardInterrupt:
            # Stop cleanly instead of sys.exit(): inside a notebook that would
            # raise SystemExit in the kernel and throw away the histories.
            print("Training interrupted by user")
            break

        if scheduler is not None:
            scheduler.step(test_measure_mean)

        if prev_best_test is None or test_measure_mean > prev_best_test:
            print("Saving...")
            # With multi-GPU wrapping the real weights live on `.module`.
            model_to_save = student_model if args.mgpu is False else student_model.module
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": loss,
                    "test_measure_mean": test_measure_mean,
                    "test_measure_std": test_measure_std,
                },
                checkpoint_path,
            )
            prev_best_test = test_measure_mean

        visualize_random_image(test_dataloader, student_model)
    return loss_storage, perf_storage