### Import libraries

In [1]:
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import time, os

from models.dataset_classification_vindr_supervised import MakeDataset_VinDr_classification
from models.unet_score_model_conditional import UNetScoreModel_conditional
from models.likelihood_computation import ode_likelihood
from models.utils import get_lr
from models.vpsde import VPSDE

### Cuda setup

In [2]:
## Devices
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print("Available GPUs: ", torch.cuda.device_count())
# print("Current device ID: ", torch.cuda.current_device())

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

try: 
    if device.type == 'cuda':
        print(torch.cuda.get_device_name(0))
        print('Memory Usage:')
        print('Allocated:', round(torch.cuda.memory_allocated(1)/1024**3,1), 'GB')
        print('Cached:   ', round(torch.cuda.memory_reserved(1)/1024**3,1), 'GB')
except: 
    print("No second GPU Found")

cuda
Available GPUs:  2
NVIDIA A30
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
NVIDIA A30
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


### Hyperparameters

In [3]:
error_tolerance = 1e-7

n_epochs = 1
batch_size = 32
lr = 1e-4 
beta_min=0.1
beta_max=20
target_size = 384
channels= 3

In [4]:
dataset= "VinDr"
sde_type= "VPSDE"

In [5]:
image_dir = "../ST_Mammo/dataset/VinDr_Mammo/Images_Processed_CLAHE"
label_dir_csv = "../ST_Mammo/dataset/VinDr_Mammo/breast-level_annotations.csv"

### Dataset

In [6]:
transform = transforms.Compose([transforms.Resize((target_size, target_size)), 
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.RandomVerticalFlip(p=0.5),
                                transforms.ToTensor()])
transform_test = transforms.Compose([transforms.Resize((target_size, target_size)),
                                transforms.ToTensor()])

train_dataloader = MakeDataset_VinDr_classification(image_dir = image_dir,
                                                        label_dir_csv = label_dir_csv,
                                                        transform=transform,
                                                        mode='train',
                                                        target_size= target_size)

test_dataloader = MakeDataset_VinDr_classification(image_dir = image_dir,
                                                        label_dir_csv = label_dir_csv,
                                                        transform=transform_test,
                                                        mode='test',
                                                        target_size= target_size)

Total benign and malignant images in train: 15209 790
Total benign and malignant images in test: 3802 198
Train and Test files: 15999, 4000. Current mode: train
Total benign and malignant images in train: 15209 790
Total benign and malignant images in test: 3802 198
Train and Test files: 15999, 4000. Current mode: test


In [7]:
# Create data loaders for training and testing

train_loader = DataLoader(train_dataloader, batch_size=batch_size, shuffle=True, num_workers=64)
test_loader = DataLoader(test_dataloader, batch_size=batch_size, num_workers=64)

### Model initialization

In [8]:
sde_model = VPSDE(
    beta_min=beta_min,
    beta_max=beta_max)

score_model = torch.nn.DataParallel(UNetScoreModel_conditional(
                        marginal_prob_std= sde_model.marginal_prob,
                        sde= sde_type,
                        n_classes= 2,
                    )).to(device)

In [9]:
def loss_fn(model, 
            sde_model, 
            x, 
            y,
            eps=1e-5, 
            mode='train'):
    """The loss function for training score-based generative models.

    Args:
    model: A PyTorch model instance that represents a 
        time-dependent score-based model.
    x: A mini-batch of training data.    
    marginal_prob_std: A function that gives the standard deviation of 
        the perturbation kernel.
    eps: A tolerance value for numerical stability.
    """
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps  
    z = torch.randn_like(x)
    mean, std = sde_model.marginal_prob(x, random_t)
    perturbed_x = (mean + std[:, None, None, None] * z).to(device)
    score = model(perturbed_x, random_t, y)
    if mode == 'test':
        loss_score = torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3))
    else:
        loss_score = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))
    return loss_score

### Training

In [10]:
optimizer = Adam(score_model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.2)
folder_name = "saved_checkpoints"
if not os.path.exists(folder_name):
    os.makedirs(folder_name)

In [11]:
for epoch in range(1, n_epochs + 1):
    epoch_time = time.time()
    curr_lr = get_lr(optimizer)
    print("#### Epoch: ", epoch, " and current learning rate: ", curr_lr, "####")
    avg_loss = 0.
    num_items = 0
    for i, data in enumerate(train_loader):
        x, y = data
        x, y = x.to(device), y.to(device)
        loss = loss_fn(
            model= score_model,
            sde_model= sde_model,
            x= x, 
            y= y,
            mode= 'train'
        )
        optimizer.zero_grad()
        loss.backward()    
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
        if (num_items + 1) % 100 == 0:
            print("training complete for items: ", num_items)
    print("Average Loss: ", avg_loss / num_items)
    print("Time taken: ", time.time() - epoch_time)
    if epoch % 20 == 0:
        scheduler.step()

#### Epoch:  1  and current learning rate:  0.0001 ####
training complete for items:  15999
Average Loss:  107719.90288382281
Time taken:  143.41941714286804
