In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import matplotlib.pyplot as plt
import torch
from visualizer import visualize_random_image, visualize_loss_acc_plot
from train import setup_train_args, build, test

  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model


In [2]:
# Available dataset choices; `dataset` is the one passed to setup_train_args in the training cells.
datasets = ["Kvasir", "CVC", "both"]
# Select by name rather than by list position, and fail fast on a typo.
dataset = "both"
assert dataset in datasets, f"unknown dataset {dataset!r}; expected one of {datasets}"

In [3]:
from train import train_epoch
def train_student(args, unet_size="", save_best=False):
    """Train a student model on its own and record its per-epoch learning curves.

    All training components come from ``build(args)``. The element unpacked into
    ``_`` (presumably the teacher model -- confirm in ``build``) and ``KLT_loss``
    are not used here.

    Args:
        args: Namespace from ``setup_train_args``. Reads ``epochs``, ``lr``,
            ``dataset``, ``lrs`` (LR scheduling is enabled only when it equals the
            string ``"true"``) and ``lrs_min``; with ``save_best`` it also reads
            ``mgpu`` if present.
        unet_size: Optional suffix appended to the model class name in the output
            path (e.g. ``512``). Empty by default, which keeps the original naming.
        save_best: If True, write a checkpoint to ``my_path + ".pt"`` every time the
            validation score improves. Off by default.

    Returns:
        ``(loss_storage, perf_storage, my_path)``: the per-epoch values returned by
        ``train_epoch``, the per-epoch mean validation score returned by ``test``,
        and the output path prefix (no extension) for plots/checkpoints. If
        training is interrupted with Ctrl-C, the histories collected so far are
        returned instead of exiting.
    """
    (device,
     train_dataloader,
     test_dataloader,
     val_dataloader,
     Dice_loss,
     BCE_loss,
     KLT_loss,  # not used by this function
     perf,
     student_model,
     _,
     optimizer,
     alpha,
     temperature) = build(args)

    loss_storage = []
    # performance metric is dice score
    perf_storage = []

    out_dir = "Trained_student_models/trained_alone"
    os.makedirs(out_dir, exist_ok=True)

    # Take lr from `args` itself rather than the notebook global `args_train_student`,
    # so the name is correct for whatever args object is passed in.
    model_name = student_model.__class__.__name__
    my_path = (f"{out_dir}/{model_name}{unet_size}_temp_{temperature}"
               f"_alpha_{alpha}_lr_{args.lr}{args.dataset}")

    scheduler = None
    if args.lrs == "true":
        # Halve the LR after 6 epochs without improvement of the validation score
        # (mode="max"); only impose a floor when a positive lrs_min was requested.
        lr_floor = {"min_lr": args.lrs_min} if args.lrs_min > 0 else {}
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, verbose=True, patience=6, **lr_floor
        )

    best_score = None
    try:
        for epoch in range(1, args.epochs + 1):
            loss = train_epoch(student_model, device, train_dataloader,
                               optimizer, epoch, Dice_loss, BCE_loss)
            test_measure_mean, test_measure_std = test(
                student_model, device, val_dataloader, epoch, perf
            )
            # add means and loss for vis
            loss_storage.append(loss)
            perf_storage.append(test_measure_mean)

            if scheduler is not None:
                scheduler.step(test_measure_mean)

            if save_best and (best_score is None or test_measure_mean > best_score):
                print("Saving...")
                # When args.mgpu is anything but False the model is assumed to be
                # wrapped (e.g. nn.DataParallel) with the real network in `.module`.
                mgpu = getattr(args, "mgpu", False)
                model_to_save = student_model if mgpu is False else student_model.module
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": model_to_save.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "loss": loss,
                        "test_measure_mean": test_measure_mean,
                        "test_measure_std": test_measure_std,
                    },
                    my_path + ".pt",
                )
                best_score = test_measure_mean

            if epoch == args.epochs:
                visualize_random_image(test_dataloader, student_model)
    except KeyboardInterrupt:
        # sys.exit() here would raise SystemExit in the notebook and throw away the
        # histories; stop early and hand back what has been collected instead.
        print("Training interrupted by user")

    return loss_storage, perf_storage, my_path

In [4]:
# UNet with inception blocks, trained alone for 60 epochs; then plot its loss/score curves.
args_train_student = setup_train_args(
    my_model="unet_inception",
    dataset=dataset,
    data_root="./data_root/",
    batch_size=8,
    epochs=60,
    learning_rate=0.001,
)
losses, perfs, path = train_student(args_train_student)
visualize_loss_acc_plot(losses, acc_list=perfs, path=f"{path}.png")

UNetInception(
  (normalize): Normalize()
  (downs): ModuleList(
    (0): DoubleConv(
      (double_conv): Sequential(
        (0): InceptionBlock(
          (conv1): Conv2d(3, 21, kernel_size=(1, 1), stride=(1, 1))
          (conv3): Conv2d(3, 21, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv5): Conv2d(3, 22, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        )
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): InceptionBlock(
          (conv1): Conv2d(64, 21, kernel_size=(1, 1), stride=(1, 1))
          (conv3): Conv2d(64, 21, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv5): Conv2d(64, 22, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        )
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv(
      (double_conv): Sequential(
        (0

AttributeError: 'types.SimpleNamespace' object has no attribute 'epoch'

In [None]:
# Attention UNet, same training budget as the inception run; then plot its curves.
args_train_student = setup_train_args(
    my_model="unet_attention",
    dataset=dataset,
    data_root="./data_root/",
    batch_size=8,
    epochs=60,
    learning_rate=0.001,
)
losses, perfs, path = train_student(args_train_student)
visualize_loss_acc_plot(losses, acc_list=perfs, path=f"{path}.png")

In [None]:
from train import train_epoch
def train_student(args, unet_size="", save_best=False):
    """Train a student model on its own and record its per-epoch learning curves.

    This cell redefines ``train_student`` (shadowing the earlier definition) to add
    ``unet_size`` to the output name. ``unet_size`` defaults to ``""`` so earlier
    call sites that pass only ``args`` keep working after this cell has run.

    All training components come from ``build(args)``. The element unpacked into
    ``_`` (presumably the teacher model -- confirm in ``build``) and ``KLT_loss``
    are not used here.

    Args:
        args: Namespace from ``setup_train_args``. Reads ``epochs``, ``lr``,
            ``dataset``, ``lrs`` (LR scheduling is enabled only when it equals the
            string ``"true"``) and ``lrs_min``; with ``save_best`` it also reads
            ``mgpu`` if present.
        unet_size: Suffix appended to the model class name in the output path
            (e.g. ``512``).
        save_best: If True, write a checkpoint to ``my_path + ".pt"`` every time the
            validation score improves. Off by default.

    Returns:
        ``(loss_storage, perf_storage, my_path)``: the per-epoch values returned by
        ``train_epoch``, the per-epoch mean validation score returned by ``test``,
        and the output path prefix (no extension) for plots/checkpoints. If
        training is interrupted with Ctrl-C, the histories collected so far are
        returned instead of exiting.
    """
    (device,
     train_dataloader,
     test_dataloader,
     val_dataloader,
     Dice_loss,
     BCE_loss,
     KLT_loss,  # not used by this function
     perf,
     student_model,
     _,
     optimizer,
     alpha,
     temperature) = build(args)

    loss_storage = []
    # performance metric is dice score
    perf_storage = []

    out_dir = "Trained_student_models/trained_alone"
    os.makedirs(out_dir, exist_ok=True)

    # Take lr from `args` itself rather than the notebook global `args_train_student`,
    # so the name is correct for whatever args object is passed in.
    model_name = student_model.__class__.__name__
    my_path = (f"{out_dir}/{model_name}{unet_size}_temp_{temperature}"
               f"_alpha_{alpha}_lr_{args.lr}{args.dataset}")

    scheduler = None
    if args.lrs == "true":
        # Halve the LR after 6 epochs without improvement of the validation score
        # (mode="max"); only impose a floor when a positive lrs_min was requested.
        lr_floor = {"min_lr": args.lrs_min} if args.lrs_min > 0 else {}
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, verbose=True, patience=6, **lr_floor
        )

    best_score = None
    try:
        for epoch in range(1, args.epochs + 1):
            loss = train_epoch(student_model, device, train_dataloader,
                               optimizer, epoch, Dice_loss, BCE_loss)
            test_measure_mean, test_measure_std = test(
                student_model, device, val_dataloader, epoch, perf
            )
            # add means and loss for vis
            loss_storage.append(loss)
            perf_storage.append(test_measure_mean)

            if scheduler is not None:
                scheduler.step(test_measure_mean)

            if save_best and (best_score is None or test_measure_mean > best_score):
                print("Saving...")
                # When args.mgpu is anything but False the model is assumed to be
                # wrapped (e.g. nn.DataParallel) with the real network in `.module`.
                mgpu = getattr(args, "mgpu", False)
                model_to_save = student_model if mgpu is False else student_model.module
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": model_to_save.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "loss": loss,
                        "test_measure_mean": test_measure_mean,
                        "test_measure_std": test_measure_std,
                    },
                    my_path + ".pt",
                )
                best_score = test_measure_mean

            if epoch == args.epochs:
                visualize_random_image(test_dataloader, student_model)
    except KeyboardInterrupt:
        # sys.exit() here would raise SystemExit in the notebook and throw away the
        # histories; stop early and hand back what has been collected instead.
        print("Training interrupted by user")

    return loss_storage, perf_storage, my_path

In [None]:
# unet512 variant for 60 epochs; 512 is passed through as the name suffix.
args_train_student = setup_train_args(
    my_model="unet512",
    dataset=dataset,
    data_root="./data_root/",
    batch_size=8,
    epochs=60,
    learning_rate=0.001,
)
losses, perfs, path = train_student(args_train_student, 512)
visualize_loss_acc_plot(losses, acc_list=perfs, path=f"{path}.png")

In [None]:
# Longer (100-epoch) run of the same unet512 configuration as the previous cell.
n_epochs = 100
args_train_student = setup_train_args(my_model="unet512", dataset=dataset,
                                      data_root="./data_root/", batch_size=8, epochs=n_epochs, learning_rate=0.001)
losses, perfs, path = train_student(args_train_student, 512)
# train_student's path prefix does not encode the epoch count, so with otherwise identical
# settings this run would reuse the 60-epoch run's plot filename and overwrite it.
visualize_loss_acc_plot(losses, acc_list=perfs, path=f"{path}_epochs_{n_epochs}.png")