# DDSM classification

In [None]:
from __future__ import print_function

import os
import torch
import torch.optim as optim
from torch.autograd import Variable

# Internal dependency
import classification as ddsm_classify
from classification import DDSMDataset
from classification import MyResNet

## Set up training variables

In [None]:
# ---------------------------------------------------------------------------
# Machine-specific locations. Swap in the commented-out alternative when the
# notebook runs on a different machine.
# ---------------------------------------------------------------------------
# Root directory handed to DDSMDataset.
# data_dir = "/Users/yairschiff/Development/PycharmProjects/ComputerVision/Project/data/"
data_dir = "/scratch/jtb470/DDSM/data"

# Directory that receives one saved state dict per epoch.
# model_res_dir = "/Users/yairschiff/Development/PycharmProjects/ComputerVision/Project/model_results_stage2/"
model_res_dir = "/scratch/yzs208/CV_Project/model_results_v3_s2"

# State dict to load before training; use "" to skip loading.
# checkpoint = "/Users/yairschiff/Development/PycharmProjects/ComputerVision/Project/model_results/model_results20.pth"
checkpoint = "/scratch/yzs208/CV_Project/model_results_v2/model_results20.pth"

# ---------------------------------------------------------------------------
# Training settings.
# ---------------------------------------------------------------------------
epochs = 20
batch_size = 2
lr = 0.01
log_interval = 50  # forwarded to ddsm_classify.train

# Without a checkpoint the model is built with only_train_heads=True and the
# saved files are tagged "stage1"; with a checkpoint they are tagged "stage2".
train_heads = not checkpoint

# Fixed seed for torch's random number generator.
torch.manual_seed(1)

## Set device

In [None]:
# Use the first CUDA GPU when one is available; otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
if use_cuda:
    # Switch the process-wide default tensor type to CUDA float tensors.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
print("Device: {}".format(device))

## Load data

In [None]:
# Both splits are read from data_dir with exclude_brightened=True and are
# wrapped in loaders that share the same batching/shuffling/worker settings.
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=1)

train_set = DDSMDataset(data_dir, dataset="train", exclude_brightened=True)
train_loader = torch.utils.data.DataLoader(train_set, **loader_options)

# NOTE(review): the validation loader is shuffled as well; presumably not
# required for evaluation -- confirm against ddsm_classify.validation.
val_set = DDSMDataset(data_dir, dataset="val", exclude_brightened=True)
val_loader = torch.utils.data.DataLoader(val_set, **loader_options)

## Load model

In [None]:
# Build the ResNet-18 based classifier. The second argument (3) is presumably
# the number of output classes -- confirm in classification.MyResNet.
model = MyResNet("resnet18", 3, only_train_heads=train_heads)
if checkpoint:
    # Remap every stored tensor onto the selected device while loading. This
    # covers the CPU-only case and also checkpoints that were saved from a
    # CUDA device index that does not exist on this machine, which a plain
    # torch.load(checkpoint) would fail to restore.
    state_dict = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state_dict)
model = model.to(device)
# Adam over all parameters exposed by the model, using the configured rate.
optimizer = optim.Adam(model.parameters(), lr=lr)

## Run training and validation

In [None]:
# Make sure the results directory exists. os.makedirs also creates missing
# parent directories, where os.mkdir would raise if a parent were absent.
if not os.path.isdir(model_res_dir):
    print(model_res_dir + " not found: making directory for results")
    os.makedirs(model_res_dir)

# The stage tag in the file name does not change between epochs:
# 1 when train_heads is set, otherwise 2.
stage = 1 if train_heads else 2

# One training pass and one validation pass per epoch, followed by saving the
# model's state dict so every epoch leaves its own file behind.
for epoch in range(1, epochs + 1):
    ddsm_classify.train(model, train_loader, optimizer, device, epoch, log_interval)
    ddsm_classify.validation(model, val_loader, device)
    model_file = os.path.join(model_res_dir,
                              "model_stage{}_{}.pth".format(stage, epoch))
    torch.save(model.state_dict(), model_file)
    print("\nSaved model to " + model_file + ".")