<a href="https://colab.research.google.com/github/thomasp05/GIF-7005-Projet/blob/develop/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pre-configuration

## Mount Google Drive

In [None]:
# Colab setup: expose Google Drive under /content/drive so the dataset and the
# model checkpoints stored in MyDrive can be read and written by later cells.
import google.colab.drive as drive

drive.mount('/content/drive')

## Clone and pull the GitHub repository

In [None]:
# Make sure clone at root
%cd /content
!pip3 install pydicom
!git clone https://github.com/thomasp05/gif-705-projet

import os
os.chdir('gif-705-projet')

In [None]:
# Update the local clone; this runs inside the repository because the
# previous cell changed into it with os.chdir('gif-705-projet').
!git pull

# Imports and load

In [8]:
import time

import torch

# NOTE(review): star imports hide where names come from. This notebook uses
# dcm_dataset and train_test_split (presumably from `dataset`) and Vanilla_CNN and
# Simple_regressor (presumably from `models`); consider importing them explicitly.
from dataset import *
from models import *

# Fix torch's random seed so that torch-driven randomness is repeatable across runs.
torch.manual_seed(111)

<torch._C.Generator at 0x7fde8e9af630>

## Hyperparameters

In [None]:
# Training configuration: number of passes over the training set, and the
# mini-batch size shared by the train and test loaders.
N_EPOCH, BATCH_SIZE = 50, 32

## Load dataset

In [13]:
# Build the DICOM dataset from Drive and split it into train / test subsets.
dataset = dcm_dataset('../drive/MyDrive/GIF-7005-Projet/gif-7005-projet/data')

train_set, test_set = train_test_split(dataset)

# Reshuffle the training samples every epoch so mini-batches do not follow the
# dataset's fixed order (the shuffle draws from torch's RNG, seeded earlier).
# The test set keeps a deterministic order; shuffling it would change nothing.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Train

In [9]:
# Feature extractor plus classification head, both placed on the first GPU.
# The head consumes the flattened CNN output (see pred1.flatten(1, -1) in training).
# NOTE(review): CNN_WIDTH * 8 * 8 presumably equals that flattened size
# (CNN_WIDTH channels over an 8x8 map) — confirm against models.Vanilla_CNN.
CNN_WIDTH = 16
cnn = Vanilla_CNN(width=CNN_WIDTH).to("cuda:0")
regressor = Simple_regressor(CNN_WIDTH * 8 * 8).to("cuda:0")

In [10]:
# Resume from weights previously saved to Drive; a network whose checkpoint
# file is absent keeps its freshly initialised parameters.
for net, ckpt_path in (
    (cnn, "../drive/MyDrive/GIF-7005-Projet/gif-7005-projet/cnn.pt"),
    (regressor, "../drive/MyDrive/GIF-7005-Projet/gif-7005-projet/regressor.pt"),
):
    if os.path.exists(ckpt_path):
        net.load_state_dict(torch.load(ckpt_path))

In [11]:
# A single Adam optimizer over both networks, so they are trained end to end.
params = [*cnn.parameters(), *regressor.parameters()]
optim = torch.optim.Adam(params, lr=1e-4)

# This loss applies the sigmoid itself, so the head's output is treated as a logit.
loss_fn = torch.nn.BCEWithLogitsLoss()

In [None]:
# Train the CNN feature extractor and the classification head jointly for N_EPOCH epochs,
# reporting the mean training loss and wall-clock time of each epoch.
cnn.train()
regressor.train()

for epoch in range(N_EPOCH):
    timer = time.time()

    img_processed = 0
    running_loss = 0.0

    # The loader yields (img, (target, bounding_box)); bounding_box is not used by
    # this classification loss.
    for img, (target, bounding_box) in train_loader:

        optim.zero_grad()

        img = img.to("cuda:0")
        # BCEWithLogitsLoss requires a floating-point target.
        target = target.to("cuda:0", dtype=torch.float)

        pred1 = cnn(img)
        # reshape(-1) instead of squeeze(): squeeze() also drops the batch dimension
        # when the last batch holds a single image, and BCEWithLogitsLoss then
        # raises because input and target shapes differ.
        # NOTE(review): assumes the head emits one logit per image — confirm.
        pred2 = regressor(pred1.flatten(1, -1)).reshape(-1)

        loss = loss_fn(pred2, target)

        loss.backward()

        optim.step()

        n_batch = img.shape[0]
        img_processed += n_batch
        # loss_fn uses the default 'mean' reduction; weight by batch size so the
        # epoch figure is a per-image mean even when the last batch is smaller.
        running_loss += loss.item() * n_batch

    print("Epoch : {}".format(epoch + 1))
    print("Loss : {:.4f}".format(running_loss / max(img_processed, 1)))
    print("Time elapsed : {:.2f}".format(time.time() - timer))

In [None]:
# Persist the trained weights to Drive; the checkpoint-loading cell picks them up on the next run.
for net, ckpt_path in (
    (cnn, "../drive/MyDrive/GIF-7005-Projet/gif-7005-projet/cnn.pt"),
    (regressor, "../drive/MyDrive/GIF-7005-Projet/gif-7005-projet/regressor.pt"),
):
    torch.save(net.state_dict(), ckpt_path)

In [None]:
# Free GPU memory held by training before evaluation: drop the optimizer state,
# the last training batch, and the last batch's outputs and loss, then return
# cached CUDA blocks to the driver.
# NOTE: optim no longer exists after this cell, so re-running the training cell
# requires re-running the optimizer cell first.
del optim, img, target, bounding_box, pred1, pred2, loss
torch.cuda.empty_cache()

In [None]:
# Accuracy of the classifier on the held-out test set.

cnn.eval()
regressor.eval()

correct = 0

# Evaluation needs no gradients; disabling autograd avoids building a graph for
# every test batch.
with torch.no_grad():
    for img, (target, bounding_box) in test_loader:

        img, target = img.to("cuda:0"), target.to("cuda:0")

        pred1 = cnn(img)
        # reshape(-1) keeps the batch dimension even for a final batch of one image.
        pred2 = regressor(pred1.flatten(1, -1)).reshape(-1)

        # The head is trained with BCEWithLogitsLoss, so pred2 holds logits:
        # logit > 0  <=>  sigmoid(logit) > 0.5. Thresholding the raw logit at 0.5
        # would correspond to a probability cut-off of about 0.62.
        correct += ((pred2 > 0) == (target == 1)).sum().item()

score = correct / len(test_loader.dataset)
print("Accuracy : {:.4f}".format(score))