In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
import sys
import numpy as np
import random

In [3]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [4]:
import torchvision
import torchvision.transforms as transforms

In [5]:
from torch.utils.tensorboard import SummaryWriter

In [6]:
import warnings
warnings.filterwarnings(action='ignore')

## Data

In [7]:
path =  '../Preprocessing/Data/'

img_data = np.load(path + 'img_data.npy')
img_label = np.load(path + 'img_age.npy')

In [8]:
data_size = len(img_data)
rand_idx = np.random.choice(data_size , data_size , replace = False)

train_size = int(0.8*data_size)

In [9]:
img_data = img_data[rand_idx]
img_label = img_label[rand_idx]

In [10]:
img_train = img_data[:train_size]
label_train = img_label[:train_size]

In [11]:
img_test = img_data[train_size:]
label_test = img_label[train_size:]

In [12]:
print('Train Data Shape\n')

print(img_train.shape)
print(label_train.shape)

Train Data Shape

(2160, 7, 512, 384, 3)
(2160, 7, 1)


In [13]:
print('Test Data Shape\n')

print(img_test.shape)
print(label_test.shape)

Test Data Shape

(540, 7, 512, 384, 3)
(540, 7, 1)


## Reshape

In [14]:
height = 512
width = 384
channel = 3
a_class = 3

In [15]:
img_train = img_train.reshape(-1 , height , width , channel)
label_train = label_train.reshape(-1,)

img_test = img_test.reshape(-1 , height , width , channel)
label_test = label_test.reshape(-1,)

In [16]:
img_size = 256
re_size = 224
batch_size = 32

mid_h = int(height/2)
mid_w = int(width/2)
mid_s = int(img_size/2)

# crop image of center : image shape (18900 , 256 , 256 , 3)
img_train = img_train[: , 
                      mid_h - mid_s : mid_h + mid_s , 
                      mid_w - mid_s : mid_w + mid_s ,
                      :]

img_test = img_test[: , 
                    mid_h - mid_s : mid_h + mid_s , 
                    mid_w - mid_s : mid_w + mid_s ,
                    :]

## Data Loader

In [17]:
class ProjectDataset(Dataset) :

    def __init__(self , data , label , class_size) :

        super(Dataset , self).__init__()

        self.data = np.transpose(data , (0,3,1,2)) # change channel first (for pytorch)
        self.label = np.eye(class_size)[label.astype('int32')] # one hot vector

    def __len__(self) :

        data_len = self.data.shape[0]

        return data_len

    def __getitem__(self , idx) :

        data_idx = self.data[idx]
        label_idx = self.label[idx]
        
        # return dict type
        sample_idx = {'image' : data_idx , 'label' : label_idx}
        
        return sample_idx

In [18]:
# train dataset
image_train_dset = ProjectDataset(img_train, label_train , a_class)

# train data loader
image_train_loader = DataLoader(image_train_dset, 
                                batch_size=batch_size,
                                shuffle=True ,
                                num_workers=4 ,
                                drop_last=True)

# test dataset
image_test_dset = ProjectDataset(img_test, label_test , a_class)

# test data loader
image_test_loader = DataLoader(image_test_dset,  
                               batch_size=batch_size,
                               shuffle=False , 
                               num_workers=4 ,
                               drop_last=True)

In [19]:
class CutMix :
    
    def __init__(self, img_height , img_width):
        
        self.h = img_height
        self.w = img_width 
        
        # combination rate sampled from the uniform distribution
        self.gen = torch.distributions.beta.Beta(1,1)
        
    def __call__(self, sample):
        
        # org image
        a_image, a_label = sample['image'], sample['label']
        
        batch_size = len(a_image)
        rand = torch.randperm(batch_size)
        
        # image which we will take patch
        b_image = a_image[rand]
        b_label = a_label[rand]
        
        # y , x point of patch
        y = torch.randint(self.h , (1,))[0]
        x = torch.randint(self.w, (1,))[0]

        # combination ratio
        r = self.gen.sample()
        
        # height and width of patch
        h = (self.h * torch.sqrt(1-r)).int()
        w = (self.w * torch.sqrt(1-r)).int()

        # org image(a) + image patch(b)
        c_image = a_image.clone()
        c_image[: , : , y:y+h , x:x+w] = b_image[: , : ,y:y+h , x:x+w]
        
        # combine label
        c_label = a_label * r + b_label * (1-r)
        
        return {'image' : c_image , 'label' : c_label}

## Device

In [20]:
USE_CUDA = torch.cuda.is_available()
random.seed(20210905)
torch.cuda.manual_seed_all(20210905)

device = torch.device("cuda" if USE_CUDA else "cpu") 

## Model

In [21]:
# resnet 101 layer
layer_dim = [3,4,23,3]
ch_dim = [64,128,256,512]

start_ch = 64
start_k = 7
inter_k = 3

In [22]:
class StartBlock(nn.Module) :
    
    def __init__(self , start_ch , start_k) :
        
        super(StartBlock , self).__init__()
        
        self.net = nn.Sequential(nn.Conv2d(3,start_ch,start_k,stride=2,padding=int(start_k/2)),
                                 nn.BatchNorm2d(start_ch),
                                 nn.ReLU(),
                                 nn.MaxPool2d(3 , stride=2 , padding=1))
        
    def forward(self , in_tensor) :
        
        o_tensor = self.net(in_tensor)
        
        return o_tensor
        

In [35]:
class ConvBlock(nn.Module) :
    
    def __init__(self, layer , in_ch , conv_ch , kernal) :
        
        super(ConvBlock , self).__init__()
        
        self.layer = layer
        self.img_size = img_size
        self.in_ch = in_ch
        self.conv_ch = conv_ch
        self.out_ch = conv_ch * 4
        self.kernal = kernal

        self.conv_net = nn.ModuleList()
        
        if self.in_ch != self.out_ch :
            self.flag_11 = True 
            self.conv_11 = nn.Conv2d(in_ch,self.out_ch,1)
        else :
            self.flag_11 = False
        
        for i in range(layer) :
            
            ch_ptr = in_ch if i == 0 else self.out_ch
            
            conv_seq = nn.Sequential(nn.Conv2d(ch_ptr,conv_ch,1), 
                                     nn.BatchNorm2d(conv_ch),
                                     nn.ReLU(),
                                     nn.Conv2d(conv_ch,conv_ch,self.kernal,padding=int(kernal/2)),
                                     nn.BatchNorm2d(conv_ch),
                                     nn.ReLU(),
                                     nn.Conv2d(conv_ch,self.out_ch,1),
                                     nn.BatchNorm2d(self.out_ch),
                                     nn.ReLU())
            self.conv_net.append(conv_seq)
        
    def forward(self, in_tensor) :
        
        tensor_ptr = in_tensor
        p_tensor = self.conv_11(in_tensor) if self.flag_11 else in_tensor
        
        for i in range(self.layer) :
            
            h_tensor = self.conv_net[i](tensor_ptr)    
            o_tensor = F.relu(h_tensor + p_tensor)
            
            tensor_ptr = o_tensor
            p_tenosr = o_tensor
            
        return tensor_ptr
        

In [44]:
class ResNet(nn.Module) :

    def __init__(self, img_size , layers , channels , kernal , 
                 in_ch , in_kernal , class_size) :

        super(ResNet , self).__init__()

        self.img_size = img_size
        self.layers = layers
        self.channels = channels
        self.kernal = kernal
        self.in_ch = in_ch
        self.in_k = in_kernal
        self.class_size = class_size
        
        self.start = StartBlock(in_ch , in_kernal)
        self.resnet = nn.ModuleList()
        self.tran = nn.ModuleList()
        self.norm = nn.ModuleList()

        size_ptr = img_size / 4
        ch_ptr = in_ch
        
        for i in range(len(layers)) :

            self.resnet.append(ConvBlock(layers[i] , ch_ptr , channels[i] , kernal))
            ch_ptr = channels[i] * 4

            if i < len(layers) - 1 :
                self.tran.append(nn.Conv2d(ch_ptr,channels[i+1]*4,1,stride=2))
                self.norm.append(nn.BatchNorm2d(channels[i+1]*4))
                
                ch_ptr = channels[i+1] * 4
                size_ptr /= 2

        self.avg_pool = nn.AvgPool2d(int(size_ptr)) # final average pooling layer
        self.o_layer = nn.Linear(channels[-1]*4 , class_size)

        self.init_param()

    def init_param(self) :

        nn.init.kaiming_normal_(self.o_layer.weight)
        nn.init.zeros_(self.o_layer.bias) 
        
        for m in self.modules() :
            if isinstance(m , nn.Conv2d) :

                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self , in_tensor) :

        batch_size = in_tensor.shape[0]
        tensor_ptr = self.start(in_tensor)
        
        for i in range(len(self.resnet)) :
    
            # conv2 block
            tensor_ptr = self.resnet[i](tensor_ptr)

            if i < len(self.tran) :
                # transition
                tensor_ptr = self.tran[i](tensor_ptr)

                # batch normalization and Relu activation
                tensor_ptr = self.norm[i](tensor_ptr)
                tensor_ptr = F.relu(tensor_ptr)

        avg_tensor = self.avg_pool(tensor_ptr)
        avg_tensor = torch.reshape(avg_tensor , (batch_size , self.channels[-1]*4))

        o_tensor = self.o_layer(avg_tensor)
    
        return o_tensor

In [45]:
epoch_size = 100
min_loss = 1e+7
init_lr = 3e-4
early_count = 0
log_count = 0

cutmix =  CutMix(img_size , img_size)

scalor = transforms.Compose([transforms.Resize((re_size,re_size)),
                             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2))])

resnet = ResNet(re_size , layer_dim , ch_dim , inter_k , start_ch , start_k , a_class).to(device)

optimizer = optim.Adam(resnet.parameters() , lr = init_lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 3 , gamma = 0.8)

### Accuracy Function

In [46]:
def acc_fn(y_output , y_label) :

    # get max argument of output 
    _ , y_out_arg = torch.max(y_output , dim = -1)
    _ , y_label_arg = torch.max(y_label , dim = -1)

    # check if output max argument if same as output label
    y_acc = (y_out_arg == y_label_arg).float()    
    y_acc = torch.mean(y_acc)

    return y_acc

### Loss Function

In [47]:
def loss_fn(y_output , y_label) :
    
    y_prob = F.softmax(y_output , dim = -1)
    y_log = -torch.log(y_prob + 1e-12)
    
    y_loss = torch.multiply(y_log , y_label)
    y_loss = torch.sum(y_loss , dim = -1)
    
    y_loss = torch.mean(y_loss)
    
    return y_loss

## Writer

In [48]:
writer = SummaryWriter('runs/resnet/cutmix/')

## Training

In [49]:
def progressLearning(value, endvalue, loss , acc , bar_length=50):
      
    percent = float(value + 1) / endvalue
    arrow = '-' * int(round(percent * bar_length)-1) + '>'
    spaces = ' ' * (bar_length - len(arrow))

    sys.stdout.write("\rPercent: [{0}] {1}/{2} \t Loss : {3:.3f} , Acc : {4:.3f}".format(arrow + spaces, value+1 , endvalue , loss , acc))
    sys.stdout.flush()

In [50]:
def evaluate(model , scalor  , test_loader , device) :

    loss = 0.0
    acc = 0.0

    with torch.no_grad() :

        model.eval()

        for img_data in test_loader :
     
            img_in , img_label = img_data['image'] , img_data['label']
            img_in = img_in.float().to(device) / 255
            img_in = scalor(img_in) # scalor
            
            img_label = img_label.float().to(device)

            img_output = model(img_in)

            loss_idx = loss_fn(img_output , img_label)
            acc_idx = acc_fn(img_output , img_label)

            loss += loss_idx
            acc += acc_idx

        model.train()

    loss /= len(test_loader)
    acc /= len(test_loader)

    return loss , acc

In [None]:
for epoch in range(epoch_size) :

    idx = 0
    
    print('Epoch : %d \t Learning Rate : %e' %(epoch , optimizer.param_groups[0]['lr'])) 
    
    for img_data in image_train_loader : 

        img_data = cutmix(img_data)
     
        img_in , img_label = img_data['image'] , img_data['label']
        img_in = img_in.float().to(device) / 255
        img_in = scalor(img_in) # scalor
        
        optimizer.zero_grad()
        
        img_label = img_label.float().to(device)
        
        img_output = resnet(img_in) 
        
        loss = loss_fn(img_output , img_label)  
        acc = acc_fn(img_output , img_label) 
        
        loss.backward()
        optimizer.step()

        if (idx + 1) % 10 == 0 :
            
            writer.add_scalar('train/loss' , loss.item() , log_count)
            writer.add_scalar('train/acc' , acc.item() , log_count)
            log_count += 1
        
        progressLearning(idx , len(image_train_loader) , loss, acc) 

        idx += 1 

    test_loss, test_acc = evaluate(resnet, scalor , image_test_loader , device) 
    
    writer.add_scalar('test/loss' , test_loss.item() , epoch)
    writer.add_scalar('test/acc' , test_acc.item() , epoch)
    
    if test_loss < min_loss :
        
        min_loss = test_loss
        torch.save({'epoch' : (epoch) ,  
                    'model_state_dict' : resnet.state_dict() , 
                    'loss' : test_loss.item() , 
                    'acc' : test_acc.item()} , 
                    f'./Model/checkpoint_resnet_cutmix.pt')        
        early_count = 0 
        
    else :
        
        early_count += 1
        if early_count >= 5 :      
            print('\nTraining Stopped')
            break

    scheduler.step()
    print('\nValidation Loss : %.4f \t Validation Acc : %.4f\n' %(test_loss , test_acc))


Epoch : 0 	 Learning Rate : 3.000000e-04
Percent: [------------------------------------------------->] 472/472 	 Loss : 1.018 , Acc : 0.500
Validation Loss : 0.6560 	 Validation Acc : 0.7018

Epoch : 1 	 Learning Rate : 3.000000e-04
Percent: [------------------------------------------------->] 472/472 	 Loss : 0.752 , Acc : 0.594
Validation Loss : 0.6463 	 Validation Acc : 0.7873

Epoch : 2 	 Learning Rate : 3.000000e-04
Percent: [------------------------------------------------->] 472/472 	 Loss : 0.627 , Acc : 0.781
Validation Loss : 0.6535 	 Validation Acc : 0.7823

Epoch : 3 	 Learning Rate : 2.400000e-04
Percent: [------------------------------------------------->] 472/472 	 Loss : 0.733 , Acc : 0.875

In [None]:
!tensorboard --logdir=./runs/resnet/  --port=6006 --bind_all