In [None]:
!pip install torchsummary

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms

from torchsummary import summary

import os, time, sys

## MobileNet v2 with LayerNorm

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Block_LN(nn.Module):
    """Inverted-residual building block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Every convolution is followed by ``nn.GroupNorm`` with a single group
    (the normalisation this notebook refers to as LayerNorm).

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        expansion: multiplier applied to ``in_planes`` for the hidden width.
        stride: stride of the depthwise convolution; a residual connection is
            only added when this is 1.
    """

    def __init__(self, in_planes, out_planes, expansion, stride):
        super().__init__()
        self.stride = stride
        hidden = expansion * in_planes

        # The ``bn*`` attribute names are kept because they are the keys under
        # which the parameters appear in ``state_dict``.
        # 1x1 pointwise expansion.
        self.conv1 = nn.Conv2d(in_planes, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.GroupNorm(1, hidden)
        # 3x3 depthwise convolution (one group per channel).
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=False)
        self.bn2 = nn.GroupNorm(1, hidden)
        # 1x1 pointwise projection; no activation follows it in forward().
        self.conv3 = nn.Conv2d(hidden, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.GroupNorm(1, out_planes)

        if stride == 1 and in_planes != out_planes:
            # Channel counts differ: project the input so it can be added.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
                nn.GroupNorm(1, out_planes),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x`` and add the shortcut branch when stride is 1."""
        features = F.relu6(self.bn1(self.conv1(x)))
        features = F.relu6(self.bn2(self.conv2(features)))
        features = self.bn3(self.conv3(features))
        if self.stride == 1:
            features = features + self.shortcut(x)
        return features


class MobileNetv2LN(nn.Module):
    """MobileNet v2 style classifier built from ``Block_LN`` blocks.

    The stem convolution and four of the stages in ``cfg`` use stride 2, so the
    spatial resolution is reduced by a factor of 32 before pooling (a 224x224
    input yields a 7x7 feature map).

    Args:
        num_classes: number of outputs of the final linear layer (default 3).
    """
    # (expansion, out_planes, num_blocks, stride)
    cfg = [(1,  16, 1, 1),
           (6,  24, 2, 2),
           (6,  32, 3, 2),
           (6,  64, 4, 2),
           (6,  96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]

    def __init__(self, num_classes=3):
        super(MobileNetv2LN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.GroupNorm(1, 32)
        self.layers = self._make_layers(in_planes=32)
        # NOTE: the original MobileNet v2 uses 1280 output channels here; this
        # variant uses 160.
        self.conv2 = nn.Conv2d(320, 160, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.GroupNorm(1, 160)
        self.linear = nn.Linear(160, num_classes)

    def _make_layers(self, in_planes):
        """Build the stack of ``Block_LN`` blocks described by ``cfg``.

        Args:
            in_planes: number of channels entering the first block.

        Returns:
            An ``nn.Sequential`` containing every block in order.
        """
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            # Only the first block of a stage uses the configured stride; the
            # remaining blocks of that stage keep the resolution (stride 1).
            block_strides = [stride] + [1]*(num_blocks-1)
            for block_stride in block_strides:
                layers.append(Block_LN(in_planes, out_planes, expansion, block_stride))
                in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        out = F.relu6(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu6(self.bn2(self.conv2(out)))
        # Global average pooling. The previous fixed 7x7 window only produced a
        # single value per channel for 7x7 feature maps (224x224 inputs);
        # adaptive pooling gives the same result there and works for other sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out



In [None]:
## util

# Console width assumed by progress_bar. It used to be read with
# os.popen('stty size', 'r'); a fixed value is used instead.
term_width = 80

# Number of character cells occupied by the bar itself.
TOTAL_BAR_LENGTH = 65.

# Timestamps shared with progress_bar: the previous call and the start of the
# current bar. Both start at the moment this cell runs.
last_time = begin_time = time.time()
def progress_bar(current, total, msg=None):
    """Draw a single-line text progress bar on stdout.

    Args:
        current: zero-based index of the current step; 0 restarts the
            total-time clock.
        total: total number of steps.
        msg: optional text appended after the timing information.
    """
    global last_time, begin_time
    if current == 0:
        # A new bar is starting: restart the elapsed-time clock.
        begin_time = time.time()

    filled = int(TOTAL_BAR_LENGTH*current/total)
    unfilled = int(TOTAL_BAR_LENGTH - filled) - 1
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * unfilled + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    fields = [
        '  Step: %s' % format_time(step_time),
        ' | Tot: %s' % format_time(tot_time),
    ]
    if msg:
        fields.append(' | ' + msg)
    status = ''.join(fields)
    sys.stdout.write(status)

    # Pad the line out to the terminal width (a negative count writes nothing).
    sys.stdout.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(status)-3))

    # Go back to the center of the bar and overlay the step counter.
    sys.stdout.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # Stay on the same line until the last step, then move to a new line.
    sys.stdout.write('\r' if current < total-1 else '\n')
    sys.stdout.flush()

def format_time(seconds):
    """Format a duration in seconds as a short string such as ``'1h1m'``.

    The duration is split into days, hours, minutes, seconds and milliseconds.
    Only the first two non-zero components (largest unit first) are shown;
    ``'0ms'`` is returned when every component is zero.
    """
    remaining = seconds
    days = int(remaining / 3600/24)
    remaining = remaining - days*3600*24
    hours = int(remaining / 3600)
    remaining = remaining - hours*3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes*60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining*1000)

    components = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    shown = [str(value) + unit for value, unit in components if value > 0]
    # Keep at most the two most significant non-zero components.
    return ''.join(shown[:2]) or '0ms'

In [None]:

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Best evaluation results seen so far and the epochs at which they occurred;
# test() updates these through `global`.
best_acc, best_acc_idx = 0, 0          # best test accuracy
best_loss, best_loss_idx = 1000, 0     # start high so the first evaluation can beat it

# start from epoch 0 or last checkpoint epoch
start_epoch = 0



In [None]:
# Optimiser and schedule settings.
learning_rate = 0.001   # initial learning rate handed to SGD
cosine_tmax = 200       # T_max of the cosine annealing schedule
batch_size = 128        # batch size of both data loaders
epochs = 200            # number of epochs run by the training loop

# The two settings below are not referenced by the code visible in this notebook.
dropout = 0.3
architecture = 'mobilenetv2_ln'

In [None]:
# Preprocessing applied to every image loaded through ImageFolder:
# resize, random flip / colour jitter, tensor conversion, then per-channel
# normalisation with mean 0.5 and std 0.5.
data_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness=(0.2, 2),
        contrast=(0.3, 2),
        saturation=(0.2, 2),
        hue=(-0.3, 0.3),
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


In [None]:
# Root directory passed to datasets.ImageFolder below (a Kaggle input path);
# ImageFolder expects one sub-directory per class underneath it.
path = "/kaggle/input/large-covid19-ct-slice-dataset/curated_data/curated_data/"

In [None]:
# Two views of the same image folder: one with the training-time augmentation
# and one with only the deterministic steps. Previously the held-out split was
# read through the random flip / colour-jitter pipeline as well, which made the
# reported test metrics noisy.
image_datasets = datasets.ImageFolder(path, data_transforms)
eval_transforms = transforms.Compose([
    step for step in data_transforms.transforms
    if not isinstance(step, (transforms.RandomHorizontalFlip, transforms.ColorJitter))
])
eval_image_datasets = datasets.ImageFolder(path, eval_transforms)

classes = image_datasets.classes
train_size = int(0.8*len(image_datasets))
test_size = len(image_datasets)-train_size

# Seed the split so the same images end up in the test set on every run.
split_seed = 0
train_dataset, _augmented_test_subset = torch.utils.data.random_split(
    image_datasets, [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed))
# Same held-out indices, but read through the augmentation-free dataset.
test_dataset = torch.utils.data.Subset(
    eval_image_datasets, _augmented_test_subset.indices)

trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Evaluation does not depend on sample order, so no shuffling is needed.
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)


In [None]:
# Instantiate the network with its default number of classes and move it to
# the selected device.
net = MobileNetv2LN()
net = net.to(device)
# torchsummary's summary() prints the table itself; wrapping it in print()
# only added a stray "None" line to the output.
summary(net, (3, 224, 224))

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_tmax)

# Creates a GradScaler once at the beginning of training. Gradient scaling is
# only used for CUDA mixed precision, so it is switched off explicitly on CPU
# instead of relying on GradScaler's warn-and-disable fallback.
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

In [None]:
# Training
def train(epoch):
    """Run one training epoch over ``trainloader`` with mixed precision.

    Uses the notebook-level ``net``, ``optimizer``, ``criterion``, ``scaler``
    and ``device``. Progress (running mean loss and accuracy) is drawn with
    ``progress_bar`` after every batch.

    Args:
        epoch: epoch number, used only for the printed header.
    """
    print('\nEpoch: %d' % epoch)

    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # Autocast here is the CUDA variant; disable it explicitly on CPU.
    use_amp = device == 'cuda'
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = net(inputs)
            loss = criterion(outputs, targets)

        # Scale the loss before backward, step through the scaler, then
        # update the scale for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

def _save_checkpoint(state, filename):
    """Write ``state`` to ``./checkpoint/<filename>``, creating the folder if needed."""
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/' + filename)


def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when it improves.

    Updates the notebook-level ``best_acc`` / ``best_loss`` (and the epochs at
    which they occurred). Saves ``ckpt.pth`` when accuracy improves and
    ``ckpt_loss.pth`` when the loss improves.

    Args:
        epoch: epoch number stored in the checkpoints and best-epoch trackers.
    """
    global best_acc, best_loss, best_acc_idx, best_loss_idx
    net.eval()
    # NOTE: test_loss is the SUM of the per-batch mean losses (the progress bar
    # shows the running mean); best_loss is tracked on the same summed scale.
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    # Accuracy over every sample seen (total equals the dataset size after a
    # full pass over the loader).
    acc = 100. * correct / total
    print({
        "Test Accuracy": acc,
        "Test Loss": test_loss
    })

    # Save checkpoint.
    if acc > best_acc:
        print('Saving best accuracy model..')
        _save_checkpoint({
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }, 'ckpt.pth')
        best_acc = acc
        best_acc_idx = epoch

    if test_loss < best_loss:
        print('Saving best loss model..')
        _save_checkpoint({
            'net': net.state_dict(),
            'acc': acc,
            'loss': test_loss,
            'epoch': epoch,
        }, 'ckpt_loss.pth')
        best_loss = test_loss
        best_loss_idx = epoch

    print({
        "best_acc": best_acc,
        "best_acc_idx": best_acc_idx,
        "best_loss": best_loss,
        "best_loss_idx" : best_loss_idx
    })

In [None]:
# Main loop: one training pass followed by one evaluation per epoch, then
# advance the learning-rate schedule by one step.
for epoch in range(start_epoch, start_epoch + epochs):
    train(epoch)
    test(epoch)  # also writes a checkpoint when accuracy or loss improves
    scheduler.step()

Because training takes too long, we stopped it after 47 epochs and evaluate the performance of the best model obtained up to that point.

In [None]:
# Evaluate the best-accuracy checkpoint written by test() rather than whatever
# weights happen to be in memory when training stops; the epoch label comes
# from the checkpoint itself instead of a hard-coded number.
best_checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
net.load_state_dict(best_checkpoint['net'])
test(best_checkpoint['epoch'])