In [None]:
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch import optim
from unet_ppm import *
from dice_loss import *

In [0]:
# Load the input patches (presumably seismic — TODO confirm) and the label
# patches that are later paired with them in DataGenerator; inputs are loaded first.
sei_patch, lab_patch = (np.load(path) for path in ('sei_patch.npy', 'lab_patch.npy'))

In [0]:
class DataGenerator(Dataset):
    """Map-style dataset pairing each input patch with its label patch.

    Item ``i`` is ``(x, y)`` where ``x`` is ``x_set[i]`` with a new leading
    axis of size 1 (``np.expand_dims(..., axis=0)``) and ``y`` is
    ``y_set[i]`` unchanged.
    NOTE(review): the new axis presumably acts as the channel dimension the
    model expects — confirm against the model's input layout.

    Args:
        x_set: indexable collection of input patches.
        y_set: indexable collection of label patches, same length as ``x_set``.

    Raises:
        ValueError: if ``x_set`` and ``y_set`` differ in length.
    """

    def __init__(self, x_set, y_set):
        # Fail fast: a length mismatch would otherwise surface as an IndexError
        # mid-epoch (or silently drop labels if y_set is longer).
        if len(x_set) != len(y_set):
            raise ValueError(
                'x_set and y_set must have the same length, got {} and {}'.format(
                    len(x_set), len(y_set)))
        self.x, self.y = x_set, y_set

    def __len__(self):
        """Number of (input, label) pairs."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``index``-th sample as ``(expanded_input, label)``."""
        sample = self.x[index]
        target = self.y[index]
        return np.expand_dims(sample, axis=0), target

In [0]:
def accuracy(out, yb):
    """Fraction of positions where the arg-max class along dim 1 equals the target.

    Args:
        out: score tensor whose class dimension is dim 1.
        yb: target class indices, broadcast-compatible with ``out.argmax(dim=1)``.

    Returns:
        0-d float tensor with the mean of the element-wise matches.
    """
    predicted = out.argmax(dim=1)
    hits = predicted.eq(yb)
    return hits.float().mean()

def train(model, optimizer, dataload, num_epochs, device,
          save_path="/content/model_0.pth", dice_weight=1e-2, log_every=10):
    """Train ``model`` with a weighted Dice + cross-entropy loss.

    The per-batch loss is ``dice_weight * MulticlassDiceLoss + CrossEntropyLoss``,
    both computed on the model output against integer (``long``) labels.

    Args:
        model: network whose output is fed to ``MulticlassDiceLoss``,
            ``torch.nn.CrossEntropyLoss`` and ``accuracy`` (arg-max over dim 1).
        optimizer: optimizer over ``model``'s parameters.
        dataload: ``DataLoader`` yielding ``(inputs, labels)`` batches.
        num_epochs: number of full passes over ``dataload``.
        device: device each batch is moved to (should match ``model``).
        save_path: file the whole trained model is written to via ``torch.save``.
        dice_weight: weight of the Dice term in the combined loss.
        log_every: print batch statistics every ``log_every`` batches.

    Returns:
        ``(model, loss_history, acc_history)``. Both histories hold one plain
        Python float per training batch (not per epoch).

    Raises:
        ValueError: if ``dataload`` yields no batches.
    """
    num_batches = len(dataload)
    if num_batches == 0:
        raise ValueError("dataload yields no batches")

    # Built once instead of per batch. NOTE(review): assumes MulticlassDiceLoss
    # keeps no state across calls — confirm in dice_loss.
    dice_criterion = MulticlassDiceLoss()
    ce_criterion = torch.nn.CrossEntropyLoss()

    acc_history = []
    loss_history = []
    # Enable training-mode behaviour (e.g. dropout / batch-norm updates) even if
    # the caller left the model in eval mode.
    model.train()
    for epoch in range(num_epochs):
        print('Starting epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        since = time.time()
        epoch_loss = 0.0
        epoch_acc = 0.0

        for idx, (x, y) in enumerate(dataload):
            optimizer.zero_grad()
            inputs = x.to(device)
            labels = y.to(device).long()
            outputs = model(inputs)
            loss = dice_weight * dice_criterion(outputs, labels) + ce_criterion(outputs, labels)
            # .item() yields CPU Python floats, so the histories are directly
            # plottable and hold no references to device tensors.
            batch_acc = accuracy(outputs, labels).item()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            epoch_acc += batch_acc
            loss_history.append(batch_loss)
            acc_history.append(batch_acc)
            if (idx + 1) % log_every == 0:
                print("%d/%d,train_loss:%0.3f,accuracy:%0.3f" % (idx + 1, num_batches, batch_loss, batch_acc))

        time_elapsed = time.time() - since
        all_epoch_loss = epoch_loss / num_batches
        all_epoch_acc = epoch_acc / num_batches
        # 1-based, matching the "Starting epoch" header above.
        print("epoch %d loss:%0.3f accuracy:%0.3f " % (epoch + 1, all_epoch_loss, all_epoch_acc))
        print('Epoch complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    torch.save(model, save_path)
    return model, loss_history, acc_history


In [None]:
# Training configuration — tune here.
SEED = 0  # seeds weight init (torch) and DataLoader shuffling for reproducible runs
BATCH_SIZE = 15
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10

np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PSPNet().to(device)
train_dataset = DataGenerator(x_set=sei_patch, y_set=lab_patch)
dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
num_epochs = NUM_EPOCHS

model_0, loss, acc = train(model, optimizer, dataloader, num_epochs, device)

In [0]:
def plot_history(values, title, ylabel):
    """Plot one training curve (one point per training batch).

    Args:
        values: sequence of numbers or 0-d tensors, as returned by ``train``.
        title: figure title.
        ylabel: y-axis label.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # float() converts 0-d tensors (even on GPU) into numbers matplotlib accepts.
    ax.plot([float(v) for v in values], label='train')
    ax.set_title(title, fontsize=20)
    ax.set_ylabel(ylabel, fontsize=20)
    # train() appends to its histories once per batch, so the x-axis counts
    # iterations, not epochs.
    ax.set_xlabel('Iteration', fontsize=20)
    ax.legend(loc='center right', fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=18)
    ax.tick_params(axis='both', which='minor', labelsize=18)
    plt.show()


plot_history(acc, 'Model accuracy', 'Accuracy')
plot_history(loss, 'Model loss', 'Loss')