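In Colab, enable a GPU: open **Runtime > Change runtime type**, select **GPU** as the hardware accelerator, then click **Save**.

Then run this cell to select the GPU as PyTorch's device and print its details.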

In [1]:
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(torch.cuda.is_available())
if torch.cuda.is_available():
    # These CUDA queries raise on a CPU-only runtime, so only run them when a GPU exists.
    print('device count', torch.cuda.device_count())
    print('current', torch.cuda.current_device())
    print('GPU', torch.cuda.get_device_name(0))
else:
    print('No GPU available - running on', device)


True
device count 1
current 0
GPU Tesla T4


Zip the `data` folder, upload it as `data.zip`, then run this cell to unzip it.

In [None]:
!unzip data.zip

Run the next four cells to define the helpers, datasets and model, and then start training.

In [3]:
from matplotlib import pyplot as plt
def show_image_mask(img, mask, cmap='gray'): # visualisation
    """Display an image and a mask side by side, with axes hidden.

    Args:
        img: array-like accepted by ``plt.imshow`` (e.g. a CPU tensor or ndarray),
            drawn in the left panel.
        mask: array-like accepted by ``plt.imshow``, drawn in the right panel.
        cmap: matplotlib colormap name applied to both panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=(5, 5))
    for ax, panel in zip(axes, (img, mask)):
        ax.imshow(panel, cmap=cmap)
        ax.axis('off')
    plt.show()  # draw the images immediately
    # Release the figure; this may be called repeatedly (e.g. from a training
    # loop) and open figures would otherwise accumulate in memory.
    plt.close(fig)

In [4]:
import torch
import torch.utils.data as data
import cv2
import os
from glob import glob

class TrainDataset(data.Dataset):
    """Paired image/mask dataset read from PNG files.

    Images are ``<root>/image/<name>.png``; the matching mask is expected at
    ``<root>/mask/<name>_mask.png``. Each item is an ``(image, mask)`` pair of
    float tensors holding the arrays exactly as stored on disk
    (``cv2.IMREAD_UNCHANGED``).

    Args:
        root: dataset directory containing the ``image`` and ``mask`` folders.
        max_items: if given, keep only the first ``max_items`` images (after
            sorting), e.g. to debug on a small subset. ``None`` keeps all.
    """

    def __init__(self, root='', max_items=None):
        super(TrainDataset, self).__init__()
        # Sort so the index -> file mapping is deterministic; glob order is arbitrary.
        self.img_files = sorted(glob(os.path.join(root, 'image', '*.png')))
        if max_items is not None:
            self.img_files = self.img_files[:max_items]
        self.mask_files = []
        for img_path in self.img_files:
            stem = os.path.splitext(os.path.basename(img_path))[0]
            self.mask_files.append(os.path.join(root, 'mask', stem + '_mask.png'))

    @staticmethod
    def _read(path):
        """Read a file with cv2 unchanged; raise FileNotFoundError instead of returning None."""
        array = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if array is None:
            # cv2.imread reports missing/unreadable files by returning None rather than raising.
            raise FileNotFoundError('could not read image file: {0}'.format(path))
        return array

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index`` as float tensors."""
        image = self._read(self.img_files[index])
        label = self._read(self.mask_files[index])
        return torch.from_numpy(image).float(), torch.from_numpy(label).float()

    def __len__(self):
        return len(self.img_files)

class TestDataset(data.Dataset):
    """Unlabelled image dataset: yields one float tensor per PNG in ``<root>/image``.

    The array is returned exactly as stored on disk (``cv2.IMREAD_UNCHANGED``).
    ``self.img_files[i]`` is the path of item ``i``.
    """

    def __init__(self, root=''):
        super(TestDataset, self).__init__()
        # Sorted so item i always refers to the same file, which lets
        # predictions be mapped back to filenames via self.img_files.
        self.img_files = sorted(glob(os.path.join(root, 'image', '*.png')))

    def __getitem__(self, index):
        """Return the image at ``index`` as a float tensor."""
        img_path = self.img_files[index]
        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None rather than raising.
            raise FileNotFoundError('could not read image file: {0}'.format(img_path))
        return torch.from_numpy(image).float()

    def __len__(self):
        return len(self.img_files)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleHybrid(nn.Module): # Define your model
    """Encoder-decoder segmentation network producing 4 per-pixel class scores.

    The encoder has three conv stages (128, 256, 512 channels), separated by
    2x2 max-pooling. The decoder has two bilinear-upsampled conv stages
    (256, 128 channels), then a final 3x3 conv producing 4 score channels.

    Input: ``(N, 1, H, W)``. Output: raw scores ``(N, 4, H, W)``, the same
    spatial size as the input for any ``H``/``W`` >= 4.

    NOTE(review): each stage applies its second conv four times in a row with
    the SAME weights (e.g. ``conv2`` x4). This is kept as-is because changing it
    would alter the parameter set and break saved checkpoints. Confirm that the
    weight sharing is intended.
    """

    def __init__(self):
        super(SimpleHybrid, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)

        # Layers are created in the original order so seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(  1, 128, 3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, 3, padding=1)

        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)

        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv6 = nn.Conv2d(512, 512, 3, padding=1)

        # Kept for backward compatibility; forward() resizes to explicit target
        # sizes so non-multiple-of-4 inputs come back at their original size.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners = False)

        self.conv7 = nn.Conv2d(512, 256, 3, padding=1)
        self.conv8 = nn.Conv2d(256, 256, 3, padding=1)

        self.conv9  = nn.Conv2d(256, 128, 3, padding=1)
        self.conv10 = nn.Conv2d(128, 128, 3, padding=1)

        self.convscores = nn.Conv2d(128, 4, 3, padding=1)

    @staticmethod
    def _stage(x, first, repeated, times=4):
        """Apply ``first`` once, then ``repeated`` ``times`` times, each followed by SELU."""
        x = F.selu(first(x))
        for _ in range(times):
            x = F.selu(repeated(x))
        return x

    @staticmethod
    def _upsample_to(x, size):
        # Bilinear resize to an exact size. For exact 2x this matches
        # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False).
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

    def forward(self, x):
        full_size = tuple(x.shape[-2:])
        x = self._stage(x, self.conv1, self.conv2)

        x = self.pool(x)  # H/2 x W/2 (e.g. 96 -> 48)
        half_size = tuple(x.shape[-2:])
        x = self._stage(x, self.conv3, self.conv4)

        x = self.pool(x)  # H/4 x W/4
        x = self._stage(x, self.conv5, self.conv6)

        # Restore the sizes recorded before each pool. For odd sizes (e.g.
        # 95 -> 47 -> 23), scale_factor=2 would give 92 and mismatch the mask.
        x = self._upsample_to(x, half_size)
        x = self._stage(x, self.conv7, self.conv8)

        x = self._upsample_to(x, full_size)
        x = self._stage(x, self.conv9, self.conv10)

        return self.convscores(x) # raw scores; CrossEntropyLoss does softmax inside it

import torch.optim as optim

# Seed torch's global RNG so the weight initialisation below is reproducible
# across runs (change SEED to try a different initialisation).
SEED = 0
torch.manual_seed(SEED)

model = SimpleHybrid() # We can now create a model using your defined segmentation model
model.to(device)

# Per-pixel classification loss on the model's raw 4-channel scores.
Loss = nn.CrossEntropyLoss()

# Alternative optimiser that can be swapped in:
# optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.8, 0.9), eps=1e-08, weight_decay=0)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.8)

In [None]:
from torch.utils.data import DataLoader

# --- Training configuration -------------------------------------------------
data_path = './data/train'
num_workers = 4
batch_size = 4
num_epochs = 1000
save_network = True
save_every = 50   # write a checkpoint after every `save_every`-th epoch
show_every = 1    # plot one prediction vs. ground truth every `show_every` epochs

train_set = TrainDataset(data_path)
training_data_loader = DataLoader(dataset=train_set, num_workers=num_workers, batch_size=batch_size, shuffle=True)

model.train() # switch model to training mode

for epoch in range(num_epochs):

    running_loss = 0.0
    num_batches = 0

    # Fetch images and labels.
    for iteration, (img, mask) in enumerate(training_data_loader):
        # Add a channel axis: conv1 takes 1 input channel, i.e. (B, 1, H, W).
        # Assumes single-channel PNGs so each batch arrives as (B, H, W) - TODO confirm.
        img = img.to(device).unsqueeze(1)
        # CrossEntropyLoss needs integer class indices; .long() keeps the tensor
        # on its current device, so this works on both CPU and GPU.
        mask = mask.to(device).long()

        # zero the parameter gradients, then forward + loss + backward + optimize
        optimizer.zero_grad()
        outputs = model(img)
        loss = Loss(outputs, mask)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        if iteration == 0 and epoch % show_every == 0:
            # Predicted class per pixel for the first image of the first batch,
            # drawn next to its ground-truth mask.
            pred_class = torch.argmax(outputs, dim=1)
            show_image_mask(pred_class[0].squeeze().cpu(), mask[0].squeeze().cpu())

    # Mean batch loss over the whole epoch (max() guards an empty dataset).
    print('epoch ', epoch, ' mean loss ', running_loss / max(num_batches, 1))

    # Checkpoint once, after all batches of a qualifying epoch.
    if save_network and epoch % save_every == save_every - 1:
        PATH = './trained_{0}e.pth'.format(epoch)
        torch.save(model.state_dict(), PATH)

print('Done training!')