In [1]:
import torch
import torchvision
from PIL import Image
import torchvision.transforms as T
import torch.nn as nn
import math
import torchvision.transforms.functional as TF
from torchvision.utils import save_image

def double_conv(in_channels, out_channels):
            #def conv to be applied

    
    conv = nn.Sequential(
            
            nn.Conv2d(in_channels, out_channels, 3 ,1 ,1 , bias = False), 
            nn.BatchNorm2d(out_channels), 
            nn.ReLU(inplace = True),
            
            nn.Conv2d(out_channels, out_channels, 3 ,1 ,1 , bias = False),
            nn.BatchNorm2d(out_channels), 
            nn.ReLU(inplace = True)
            
                            )

    return conv

In [2]:
'''
Think about the sizes of the feature maps. 500x500 downsamples to 250x250, okay. 
That downsamples to 125x125, okay. But what do you get when you downsample 125x125 
by factor 2 (as the max pooling layers do)? Can you have feature maps with fractions
of sizes (no)? Accordingly, this will result in feature maps of size 62x62 (basically 
cutting off one pixel in the input), which will later upsample back to 124x124 and create
a mismatch.
'''


'\nThink about the sizes of the feature maps. 500x500 downsamples to 250x250, okay. \nThat downsamples to 125x125, okay. But what do you get when you downsample 125x125 \nby factor 2 (as the max pooling layers do)? Can you have feature maps with fractions\nof sizes (no)? Accordingly, this will result in feature maps of size 62x62 (basically \ncutting off one pixel in the input), which will later upsample back to 124x124 and create\na mismatch.\n'

In [14]:
def crop_sc(tensor, target_tensor): #crop the skip connection to size of upsample target tensor-> upsample, tensor-> skip connection
    #target_size = target_tensor.size()[2] #batch size, channels, height, width
    #tensor_size = tensor.size()[2]
    #delta = (tensor_size - target_size)//2 #64 - 56
    
    tensor = TF.resize(tensor, size=target_tensor.shape[2:]) 
    #print('tensor_size ',tensor_size,'target_size; ',target_size,'delta:',delta)

    return tensor
    #tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta] #4:60 , 4:60

def double_conv(in_ch, out_ch):
        #def conv to be applied
    conv = nn.Sequential( #apply 2 conv sequantially
        nn.Conv2d(in_ch, out_ch, 3),
        nn.ReLU(inplace = True), 
            
        nn.Conv2d(out_ch, out_ch, 3),
        nn.ReLU(inplace = True)
    )
    return conv

class UNet(nn.Module):
    def __init__(self, in_channels = 3, out_channels = 3, features = [64,128,256,512,1024]):
        
        super(UNet, self).__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        #bunlar module ile liste icinde tutulabilir 
        
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        
        for feature in features:
            #self.d_conv = double_conv(in_channels,feature)
            print(in_channels,feature)
            self.downs.append(double_conv(in_channels,feature))
            in_channels = feature
            
            
        for feature in reversed(features[1:]):

            print('transpose: ', feature, feature//2)
            self.ups.append(nn.ConvTranspose2d(feature,feature//2,2,2))
            self.ups.append(double_conv(feature,feature//2))
            
        print('bottleneck', features[-1]//2,features[-1])
        
        self.bottleneck = double_conv(features[-1]//2,features[-1]) #512, 1024
        self.out = nn.Conv2d(features[0],2,3) #64,2,1

    def forward(self, x):
        skip_connections = []  
        print("x",x.size())
        for down in self.downs[:-1]:
            x = down(x)
            print("before max pool",x.size())
            skip_connections.append(x) #we will use this in concatination with up parts
            x = self.max_pool(x)
            
            
        x = self.bottleneck(x) #apply the doubleconv at the bottleneck step, bottleneck ile basla
        print("convolved bottleneck: ",x.size())
        
        skip_connections = skip_connections[::-1] #to reverse the list
        
        for idx in range(0, len(self.ups) , 2):
            print(idx)
            print(self.ups[idx])
            x = self.ups[idx](x) # apply transpose conv, start with bottleneck 
            s_con = skip_connections[idx//2] #bc idx = 0,2,4,6,8.. get skip connection, skip connection is the last layer form contracting path
            print("sizes", x.size(),s_con.size())
            #s_con is 64x64x512, 136x136x256, 
            y = crop_sc(s_con,x) #crop the skip connectio 64 needs to be resized to 56
            
                #concat
            concat = torch.cat([x,y] ,1) #concat bottleneck 512x56x56 with cropped 64 from contracting layer (56x56x512)
            ##element wise addition output is 56x56x512
            print('doubleconv id',idx+1)
            x = self.ups[idx+1](concat) #apply double conv
            #double convs are stored in indexes 1,3,5
            print('first layer done out should be 56x56x1024, now apply double conv output should be 52x52x512 cunku topluyoruy birbirne eklemiyoruy asko', x.size())
        print('hamd')                    
        return self.out(x)


In [15]:
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image

#transform = T.ToPILImage()

In [16]:
#model gets as input a single sample -sample image converted to torch tensor- outputs another image 

if __name__ == "__main__":
    image = torch.rand((3,3,572,572))
    model = UNet()
    print(model(image))
    output = model(image)
    model.forward


3 64
64 128
128 256
256 512
512 1024
transpose:  1024 512
transpose:  512 256
transpose:  256 128
transpose:  128 64
bottleneck 512 1024
x torch.Size([3, 3, 572, 572])
before max pool torch.Size([3, 64, 568, 568])
before max pool torch.Size([3, 128, 280, 280])
before max pool torch.Size([3, 256, 136, 136])
before max pool torch.Size([3, 512, 64, 64])
convolved bottleneck:  torch.Size([3, 1024, 28, 28])
0
ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
sizes torch.Size([3, 512, 56, 56]) torch.Size([3, 512, 64, 64])
doubleconv id 1
first layer done out should be 56x56x1024, now apply double conv output should be 52x52x512 cunku topluyoruy birbirne eklemiyoruy asko torch.Size([3, 512, 52, 52])
2
ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
sizes torch.Size([3, 256, 104, 104]) torch.Size([3, 256, 136, 136])
doubleconv id 3
first layer done out should be 56x56x1024, now apply double conv output should be 52x52x512 cunku topluyoruy birbirne eklemiyoruy asko torc

    '''
    image = Image.open(r"C:\Users\derbent.z\Desktop\images\All Images\new_data\img_0202\20220222_101933._0.bmp")
    image.show()
    convert_tensor = transforms.ToTensor()
    image = convert_tensor(image)
    image = image.unsqueeze(0)
    print(image)
    print(image.shape)
    '''
        #transform = T.ToPILImage()
    #img = transform(output)
    '''im = output[0]
    save_image(im, "img1.png")
    '''
    #im.show()

## Data Set 


In [12]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from imageio import imread
import os

class DT_DataSet(Dataset):
    #when an object of class DT_DataSet is created the attributes given would be masks and images folder names
    def __init__(self,masks_path,images_path,transform = None):
        self.masks_path = masks_path
        self.images_path = images_path
        self.transform = transform
        
        #load the whole dataset, holds name of each single image in the dataset
        self.masks = sorted(os.listdir(self.masks_path))
        self.images = sorted(os.listdir(self.images_path))
    
    def __len__(self):
        return len(self.masks)
        
    def __getitem__(self,index):
        #loads single data sample in the dataset, dataloader uses __getitem__ to get a single pair of (image,mask)
        #combines multiple of these together to form a batch
    
        #combine image name with the folder name to create a path for the image(onesingle image in the dataset not the whole set)
        mask_filename,image_filename = os.path.join(self.masks_path, self.masks[index]),os.path.join(self.images_path, self.images[index])
        
        #read the image as an array
        mask, image = imread(mask_filename), imread(image_filename)
     
        if self.transform != None:
            mask, image = self.transform(mask),self.transform(image)
            
        return mask, image
#give folder name iterate through each file in the folder
#dataset object holds the entire dataset
#train data will be augmented images, validation data will be training images
val_mask_path = r"C:\Users\derbent.z\Desktop\Training attempts\training_0302\mask train set"
val_image_path = r"C:\Users\derbent.z\Desktop\Training attempts\training_0302\train set"
dataset = DT_DataSet(val_mask_path, val_image_path,transforms.ToTensor())
#dataloader loads this dataset in batched for the training of the network
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

run training -> epoch size kadar bir iteration olacak, main metod veya run trianingde yapilabilir 

her epoch icin batch size kadar bir iteration olacak  
her batch sizeda tarinset(augmented images) pairler dönülecek

loss epoch sonunda hesaplanacak? yoksa her batchtan sonra mi 

## Train the Network

To train the network:
- **batch size:** 32
- **learning rate:** 0.001
- **epoch size:** 100
- **optimizer:** Adam
- **loss function:** Cross entropy
- **val set:** train images
- **train set:** augmented set

- model: e.g. the U-Net
- device: CPU or GPU
- criterion: loss function (e.g. CrossEntropyLoss, DiceCoefficientLoss)
- optimizer: e.g. SGD
- training_DataLoader: a training dataloader
- validation_DataLoader: a validation dataloader
- lr_scheduler: a learning rate scheduler (optional)
- epochs: The number of epochs we want to train
- epoch: The epoch number from where training should start

1- In train function iterate over the dataloader (holds the whole dataset and splits into batches) 
2- train the model
3- calculate loss by comparin pred outputmask with gt output mask 
4- store the loss's in a list
5- apply backward pass
6- use gradient
7- use optimizer to update the model parameters
7- compare loss's 
8- determine the lowest loss 
9- training over

In validation:
- iterate over val set
- calc loss
- determine the best loss
- no gradient no backprop, we are not learning in validation 

In [None]:
Training the model:
    - Determine the epoch size
    - split training set to batches (of 32)
    - for each sample in the batch in a single epoch (50 batches in a single epoch 32 samples in a single batch):
        run forward pass
        calc loss
        run backprop
        compute gradients respective to the loss func.
        add to accumulated gradients
    - Obtain accumulated gradients at the end of each batch 
    - adjust the parameters at the end of each batch by using the Adam optimizer
    - after each batch in a single epoch is complete
    - run validation set through the network store the loss (after every epoch)
    - Run the remaining batches, if the loss is not decreasing then save the model parameters with the lowest loss
    
    -the input image is going to be a 4d torch tensor input is 3 channels output is 3 channels (later when adding the splice it will be 4 channels)
    - convert the 3d tensor to 4d tensor to use in training the model 

## Training the model 

In [None]:
#epoch size 100
#dataloader splits the dataset into size 32 batches
class Train():
    def __init__(self,model,criterion,backprop,optimizer,train_data, val_data,epoch,batch_size):
        self.model = model
        self.criterion = criterion
        self.backprop = backprop
        self.optimizer = optimizer
        self.train_data
        self.val_data
        

def train():
    #model() my model gets as input a single image as a torch tensor
    for image, mask, gt in batch:
        out = self.model(image,mask,gt)
        loss = self.criterion(out, gt)
        self.backprop(loss)
        self.optimizer()
        
        
def run_train():
    
    for i in range(epoch_size):
        train()
        run on validation 
        val_loss.append(loss)
    

In [None]:
sample code:

In [None]:

from torch import nn
import torch


@torch.jit.script
def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor):
    """
    Center-crops the encoder_layer to the size of the decoder_layer,
    so that merging (concatenation) between levels/blocks is possible.
    This is only necessary for input sizes != 2**n for 'same' padding and always required for 'valid' padding.
    """
    if encoder_layer.shape[2:] != decoder_layer.shape[2:]:
        ds = encoder_layer.shape[2:]
        es = decoder_layer.shape[2:]
        assert ds[0] >= es[0]
        assert ds[1] >= es[1]
        if encoder_layer.dim() == 4:  # 2D
            encoder_layer = encoder_layer[
                            :,
                            :,
                            ((ds[0] - es[0]) // 2):((ds[0] + es[0]) // 2),
                            ((ds[1] - es[1]) // 2):((ds[1] + es[1]) // 2)
                            ]
        elif encoder_layer.dim() == 5:  # 3D
            assert ds[2] >= es[2]
            encoder_layer = encoder_layer[
                            :,
                            :,
                            ((ds[0] - es[0]) // 2):((ds[0] + es[0]) // 2),
                            ((ds[1] - es[1]) // 2):((ds[1] + es[1]) // 2),
                            ((ds[2] - es[2]) // 2):((ds[2] + es[2]) // 2),
                            ]
    return encoder_layer, decoder_layer


def conv_layer(dim: int):
    if dim == 3:
        return nn.Conv3d
    elif dim == 2:
        return nn.Conv2d


def get_conv_layer(in_channels: int,
                   out_channels: int,
                   kernel_size: int = 3,
                   stride: int = 1,
                   padding: int = 1,
                   bias: bool = True,
                   dim: int = 2):
    return conv_layer(dim)(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                           bias=bias)


def conv_transpose_layer(dim: int):
    if dim == 3:
        return nn.ConvTranspose3d
    elif dim == 2:
        return nn.ConvTranspose2d


def get_up_layer(in_channels: int,
                 out_channels: int,
                 kernel_size: int = 2,
                 stride: int = 2,
                 dim: int = 3,
                 up_mode: str = 'transposed',
                 ):
    if up_mode == 'transposed':
        return conv_transpose_layer(dim)(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
    else:
        return nn.Upsample(scale_factor=2.0, mode=up_mode)


def maxpool_layer(dim: int):
    if dim == 3:
        return nn.MaxPool3d
    elif dim == 2:
        return nn.MaxPool2d


def get_maxpool_layer(kernel_size: int = 2,
                      stride: int = 2,
                      padding: int = 0,
                      dim: int = 2):
    return maxpool_layer(dim=dim)(kernel_size=kernel_size, stride=stride, padding=padding)


def get_activation(activation: str):
    if activation == 'relu':
        return nn.ReLU()
    elif activation == 'leaky':
        return nn.LeakyReLU(negative_slope=0.1)
    elif activation == 'elu':
        return nn.ELU()


def get_normalization(normalization: str,
                      num_channels: int,
                      dim: int):
    if normalization == 'batch':
        if dim == 3:
            return nn.BatchNorm3d(num_channels)
        elif dim == 2:
            return nn.BatchNorm2d(num_channels)
    elif normalization == 'instance':
        if dim == 3:
            return nn.InstanceNorm3d(num_channels)
        elif dim == 2:
            return nn.InstanceNorm2d(num_channels)
    elif 'group' in normalization:
        num_groups = int(normalization.partition('group')[-1])  # get the group size from string
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)


class Concatenate(nn.Module):
    def __init__(self):
        super(Concatenate, self).__init__()

    def forward(self, layer_1, layer_2):
        x = torch.cat((layer_1, layer_2), 1)

        return x


class DownBlock(nn.Module):
    """
    A helper Module that performs 2 Convolutions and 1 MaxPool.
    An activation follows each convolution.
    A normalization layer follows each convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 pooling: bool = True,
                 activation: str = 'relu',
                 normalization: str = None,
                 dim: str = 2,
                 conv_mode: str = 'same'):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling
        self.normalization = normalization
        if conv_mode == 'same':
            self.padding = 1
        elif conv_mode == 'valid':
            self.padding = 0
        self.dim = dim
        self.activation = activation

        # conv layers
        self.conv1 = get_conv_layer(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=self.padding,
                                    bias=True, dim=self.dim)
        self.conv2 = get_conv_layer(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=self.padding,
                                    bias=True, dim=self.dim)

        # pooling layer
        if self.pooling:
            self.pool = get_maxpool_layer(kernel_size=2, stride=2, padding=0, dim=self.dim)

        # activation layers
        self.act1 = get_activation(self.activation)
        self.act2 = get_activation(self.activation)

        # normalization layers
        if self.normalization:
            self.norm1 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)
            self.norm2 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)

    def forward(self, x):
        y = self.conv1(x)  # convolution 1
        y = self.act1(y)  # activation 1
        if self.normalization:
            y = self.norm1(y)  # normalization 1
        y = self.conv2(y)  # convolution 2
        y = self.act2(y)  # activation 2
        if self.normalization:
            y = self.norm2(y)  # normalization 2

        before_pooling = y  # save the outputs before the pooling operation
        if self.pooling:
            y = self.pool(y)  # pooling
        return y, before_pooling


class UpBlock(nn.Module):
    """
    A helper Module that performs 2 Convolutions and 1 UpConvolution/Upsample.
    An activation follows each convolution.
    A normalization layer follows each convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 activation: str = 'relu',
                 normalization: str = None,
                 dim: int = 3,
                 conv_mode: str = 'same',
                 up_mode: str = 'transposed'
                 ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalization = normalization
        if conv_mode == 'same':
            self.padding = 1
        elif conv_mode == 'valid':
            self.padding = 0
        self.dim = dim
        self.activation = activation
        self.up_mode = up_mode

        # upconvolution/upsample layer
        self.up = get_up_layer(self.in_channels, self.out_channels, kernel_size=2, stride=2, dim=self.dim,
                               up_mode=self.up_mode)

        # conv layers
        self.conv0 = get_conv_layer(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0,
                                    bias=True, dim=self.dim)
        self.conv1 = get_conv_layer(2 * self.out_channels, self.out_channels, kernel_size=3, stride=1,
                                    padding=self.padding,
                                    bias=True, dim=self.dim)
        self.conv2 = get_conv_layer(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=self.padding,
                                    bias=True, dim=self.dim)

        # activation layers
        self.act0 = get_activation(self.activation)
        self.act1 = get_activation(self.activation)
        self.act2 = get_activation(self.activation)

        # normalization layers
        if self.normalization:
            self.norm0 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)
            self.norm1 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)
            self.norm2 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)

        # concatenate layer
        self.concat = Concatenate()

    def forward(self, encoder_layer, decoder_layer):
        """ Forward pass
        Arguments:
            encoder_layer: Tensor from the encoder pathway
            decoder_layer: Tensor from the decoder pathway (to be up'd)
        """
        up_layer = self.up(decoder_layer)  # up-convolution/up-sampling
        cropped_encoder_layer, dec_layer = autocrop(encoder_layer, up_layer)  # cropping

        if self.up_mode != 'transposed':
            # We need to reduce the channel dimension with a conv layer
            up_layer = self.conv0(up_layer)  # convolution 0
        up_layer = self.act0(up_layer)  # activation 0
        if self.normalization:
            up_layer = self.norm0(up_layer)  # normalization 0

        merged_layer = self.concat(up_layer, cropped_encoder_layer)  # concatenation
        y = self.conv1(merged_layer)  # convolution 1
        y = self.act1(y)  # activation 1
        if self.normalization:
            y = self.norm1(y)  # normalization 1
        y = self.conv2(y)  # convolution 2
        y = self.act2(y)  # acivation 2
        if self.normalization:
            y = self.norm2(y)  # normalization 2
        return y


class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 2,
                 n_blocks: int = 4,
                 start_filters: int = 32,
                 activation: str = 'relu',
                 normalization: str = 'batch',
                 conv_mode: str = 'same',
                 dim: int = 2,
                 up_mode: str = 'transposed'
                 ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_blocks = n_blocks
        self.start_filters = start_filters
        self.activation = activation
        self.normalization = normalization
        self.conv_mode = conv_mode
        self.dim = dim
        self.up_mode = up_mode

        self.down_blocks = []
        self.up_blocks = []

        # create encoder path
        for i in range(self.n_blocks):
            num_filters_in = self.in_channels if i == 0 else num_filters_out
            num_filters_out = self.start_filters * (2 ** i)
            pooling = True if i < self.n_blocks - 1 else False

            down_block = DownBlock(in_channels=num_filters_in,
                                   out_channels=num_filters_out,
                                   pooling=pooling,
                                   activation=self.activation,
                                   normalization=self.normalization,
                                   conv_mode=self.conv_mode,
                                   dim=self.dim)

            self.down_blocks.append(down_block)

        # create decoder path (requires only n_blocks-1 blocks)
        for i in range(n_blocks - 1):
            num_filters_in = num_filters_out
            num_filters_out = num_filters_in // 2

            up_block = UpBlock(in_channels=num_filters_in,
                               out_channels=num_filters_out,
                               activation=self.activation,
                               normalization=self.normalization,
                               conv_mode=self.conv_mode,
                               dim=self.dim,
                               up_mode=self.up_mode)

            self.up_blocks.append(up_block)

        # final convolution
        self.conv_final = get_conv_layer(num_filters_out, self.out_channels, kernel_size=1, stride=1, padding=0,
                                         bias=True, dim=self.dim)

        # add the list of modules to current module
        self.down_blocks = nn.ModuleList(self.down_blocks)
        self.up_blocks = nn.ModuleList(self.up_blocks)

        # initialize the weights
        self.initialize_parameters()

    @staticmethod
    def weight_init(module, method, **kwargs):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
            method(module.weight, **kwargs)  # weights

    @staticmethod
    def bias_init(module, method, **kwargs):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
            method(module.bias, **kwargs)  # bias

    def initialize_parameters(self,
                              method_weights=nn.init.xavier_uniform_,
                              method_bias=nn.init.zeros_,
                              kwargs_weights={},
                              kwargs_bias={}
                              ):
        for module in self.modules():
            self.weight_init(module, method_weights, **kwargs_weights)  # initialize weights
            self.bias_init(module, method_bias, **kwargs_bias)  # initialize bias

    def forward(self, x: torch.tensor):
        encoder_output = []

        # Encoder pathway
        for module in self.down_blocks:
            x, before_pooling = module(x)
            encoder_output.append(before_pooling)

        # Decoder pathway
        for i, module in enumerate(self.up_blocks):
            before_pool = encoder_output[-(i + 2)]
            x = module(before_pool, x)

        x = self.conv_final(x)

        return x

    def __repr__(self):
        attributes = {attr_key: self.__dict__[attr_key] for attr_key in self.__dict__.keys() if '_' not in attr_key[0] and 'training' not in attr_key}
        d = {self.__class__.__name__: attributes}
        return f'{d}'
view rawunet.py hosted with ❤ by GitHub

In [16]:
sample training:

SyntaxError: invalid syntax (3262890820.py, line 1)

In [None]:
import numpy as np
import torch


class Trainer:
    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_DataLoader: torch.utils.data.Dataset,
                 validation_DataLoader: torch.utils.data.Dataset = None,
                 lr_scheduler: torch.optim.lr_scheduler = None,
                 epochs: int = 100,
                 epoch: int = 0,
                 notebook: bool = False
                 ):

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_DataLoader = training_DataLoader
        self.validation_DataLoader = validation_DataLoader
        self.device = device
        self.epochs = epochs
        self.epoch = epoch
        self.notebook = notebook

        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []

    def run_trainer(self):

        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange

        progressbar = trange(self.epochs, desc='Progress')
        for i in progressbar:
            """Epoch counter"""
            self.epoch += 1  # epoch counter

            """Training block"""
            self._train()

            """Validation block"""
            if self.validation_DataLoader is not None:
                self._validate()

            """Learning rate scheduler block"""
            if self.lr_scheduler is not None:
                if self.validation_DataLoader is not None and self.lr_scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                    self.lr_scheduler.batch(self.validation_loss[i])  # learning rate scheduler step with validation loss
                else:
                    self.lr_scheduler.batch()  # learning rate scheduler step
        return self.training_loss, self.validation_loss, self.learning_rate

    def _train(self):

        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange

        self.model.train()  # train mode
        train_losses = []  # accumulate the losses here
        batch_iter = tqdm(enumerate(self.training_DataLoader), 'Training', total=len(self.training_DataLoader),
                          leave=False)

        for i, (x, y) in batch_iter:
            input, target = x.to(self.device), y.to(self.device)  # send to device (GPU or CPU)
            self.optimizer.zero_grad()  # zerograd the parameters
            out = self.model(input)  # one forward pass
            loss = self.criterion(out, target)  # calculate loss
            loss_value = loss.item()
            train_losses.append(loss_value)
            loss.backward()  # one backward pass
            self.optimizer.step()  # update the parameters

            batch_iter.set_description(f'Training: (loss {loss_value:.4f})')  # update progressbar

        self.training_loss.append(np.mean(train_losses))
        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

        batch_iter.close()

    def _validate(self):

        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange

        self.model.eval()  # evaluation mode
        valid_losses = []  # accumulate the losses here
        batch_iter = tqdm(enumerate(self.validation_DataLoader), 'Validation', total=len(self.validation_DataLoader),
                          leave=False)

        for i, (x, y) in batch_iter:
            input, target = x.to(self.device), y.to(self.device)  # send to device (GPU or CPU)

            with torch.no_grad():
                out = self.model(input)
                loss = self.criterion(out, target)
                loss_value = loss.item()
                valid_losses.append(loss_value)

                batch_iter.set_description(f'Validation: (loss {loss_value:.4f})')

        self.validation_loss.append(np.mean(valid_losses))

        batch_iter.close()
view rawtrainer.py hosted with ❤ by GitHub