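# DDSM classification

Trains a `MyResNet("resnet18", ...)` model on the DDSM `train` split and evaluates it on the `val` split after every epoch, saving a checkpoint each epoch. If `checkpoint` is non-empty, the model is initialized from those weights and checkpoints are saved as stage 2. Otherwise `only_train_heads=True` is passed to the model and checkpoints are saved as stage 1.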

In [1]:
from __future__ import print_function

import os
import torch
import torch.optim as optim
from torch.autograd import Variable

# Internal dependency
import classification as ddsm_classify
from classification import DDSMDataset
from classification import MyResNet

## Set up training variables

In [2]:
# Paths default to the cluster locations used for the logged run. Override them
# with environment variables on another machine instead of editing this cell.
data_dir = os.environ.get("DDSM_DATA_DIR", "/scratch/jtb470/DDSM/data")
model_res_dir = os.environ.get("DDSM_MODEL_RES_DIR", "/scratch/yzs208/CV_Project/model_results_s2_nb")
batch_size = 2
epochs = 20
# NOTE(review): 0.01 is a high learning rate for Adam when the full network is
# trained from a checkpoint. In the logged run the validation loss spiked to 18.5
# at epoch 6, and accuracy then stayed at ~351/738. Consider a smaller value for stage 2.
lr = 0.01
# Weights to initialize the model from. Set to "" to skip loading a checkpoint;
# that also turns on `only_train_heads` (stage 1) via `train_heads` below.
checkpoint = os.environ.get("DDSM_CHECKPOINT", "/scratch/yzs208/CV_Project/model_results_v2/model_results20.pth")
train_heads = checkpoint == ""
# Passed through to ddsm_classify.train (presumably batches between progress logs).
log_interval = 50
# The trailing ';' keeps the returned Generator object out of the cell output.
torch.manual_seed(1);

<torch._C.Generator at 0x2aac3d7c87f0>

## Set device

In [3]:
# Use the first GPU when CUDA is available; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
if use_cuda:
    # Newly created float tensors default to CUDA from here on.
    # NOTE(review): set_default_tensor_type is deprecated in newer PyTorch releases
    # (torch.set_default_dtype / torch.set_default_device replace it). It is kept
    # here for compatibility with older versions.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
print("Device: {}".format(device))

Device: cuda:0


## Load data

In [4]:
def make_loader(split, shuffle):
    """Build a DataLoader over one DDSM split, excluding brightened images.

    split   -- passed as DDSMDataset's ``dataset`` argument (e.g. "train" or "val").
    shuffle -- whether the DataLoader reshuffles the split every epoch.
    """
    return torch.utils.data.DataLoader(DDSMDataset(data_dir, dataset=split, exclude_brightened=True),
                                       batch_size=batch_size, shuffle=shuffle, num_workers=1)


train_loader = make_loader("train", shuffle=True)
# Validation is not shuffled. Evaluation presumably aggregates over the whole
# split (it reports totals such as 321/738), so order does not matter. A shuffling
# sampler would also draw from the global torch RNG that drives the training shuffle.
val_loader = make_loader("val", shuffle=False)

## Load model

In [5]:
# NOTE(review): the second argument (3) is presumably the number of output
# classes -- confirm against MyResNet.
model = MyResNet("resnet18", 3, only_train_heads=train_heads)
if checkpoint:
    # map_location=device loads the saved tensors straight onto the device in use,
    # so a GPU-saved checkpoint also loads on a CPU-only machine.
    state_dict = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state_dict)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

## Run training and validation

In [6]:
# Create the results directory before the first save. os.makedirs also creates
# missing parent directories, where os.mkdir would raise.
if not os.path.isdir(model_res_dir):
    print(model_res_dir + " not found: making directory for results")
    os.makedirs(model_res_dir)

# Checkpoints are tagged stage 1 when only the heads are trained, stage 2 otherwise.
stage = 1 if train_heads else 2
for epoch in range(1, epochs + 1):
    ddsm_classify.train(model, train_loader, optimizer, device, epoch, log_interval)
    ddsm_classify.validation(model, val_loader, device)
    # One file per epoch (e.g. model_stage2_7.pth), so any epoch can be reloaded later.
    model_file = os.path.join(model_res_dir, "model_stage{}_{}.pth".format(stage, epoch))
    torch.save(model.state_dict(), model_file)
    print("\nSaved model to " + model_file + ".")

Training epoch 1...

Validation set: Average loss: 1.1083, Accuracy: 321/738 (43%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_1.pth.
Training epoch 2...

Validation set: Average loss: 1.0136, Accuracy: 337/738 (45%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_2.pth.
Training epoch 3...



Validation set: Average loss: 0.9873, Accuracy: 334/738 (45%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_3.pth.
Training epoch 4...

Validation set: Average loss: 0.9951, Accuracy: 335/738 (45%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_4.pth.
Training epoch 5...



Validation set: Average loss: 1.0017, Accuracy: 292/738 (39%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_5.pth.
Training epoch 6...

Validation set: Average loss: 18.4993, Accuracy: 262/738 (35%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_6.pth.
Training epoch 7...

Validation set: Average loss: 0.9939, Accuracy: 352/738 (47%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_7.pth.
Training epoch 8...



Validation set: Average loss: 0.9927, Accuracy: 351/738 (47%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_8.pth.
Training epoch 9...

Validation set: Average loss: 0.9872, Accuracy: 351/738 (47%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_9.pth.
Training epoch 10...



Validation set: Average loss: 0.9891, Accuracy: 351/738 (47%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_10.pth.
Training epoch 11...

Validation set: Average loss: 0.9925, Accuracy: 354/738 (47%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_11.pth.
Training epoch 12...

Validation set: Average loss: 0.9907, Accuracy: 351/738 (47%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_12.pth.
Training epoch 13...



Validation set: Average loss: 0.9892, Accuracy: 351/738 (47%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_13.pth.
Training epoch 14...

Validation set: Average loss: 0.9907, Accuracy: 351/738 (47%)


Saved model to /scratch/yzs208/CV_Project/model_results_s2_nb/model_stage2_14.pth.
Training epoch 15...


KeyboardInterrupt: 