In [1]:
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
'''
!pip3 install wandb
!pip3 install torchsummary

#import os
import sys
import time
import math
import shutil 

import torch.nn as nn
import torch.nn.init as init

import wandb
import warnings
import torch

import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import sys
import matplotlib.pyplot as plt
from torchsummary import summary
import numpy as np
wandb.init(project="TinyImagenetModelParallel", name="Resnet50ModelParallel")



Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/share/apps/python/3.8.6/intel/bin/python -m pip install --upgrade pip' command.[0m
Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/share/apps/python/3.8.6/intel/bin/python -m pip install --upgrade pip' command.[0m


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmnk2978[0m ([33mmrunal[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            if m.bias:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias:
                init.constant(m.bias, 0)

_, term_width = shutil.get_terminal_size()
#_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f

In [3]:
warnings.simplefilter(action='ignore', category=FutureWarning)

pkgpath = './'
save_path = './results/'

if os.path.isdir(save_path) == False:
    os.makedirs(save_path)

sys.path.append(pkgpath)



device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
        transforms.RandomCrop(64, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

train_dir = './tiny-imagenet-200/train'

trainset = torchvision.datasets.ImageFolder(
    train_dir, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=100, shuffle=True, num_workers=0)

test_dir = './tiny-imagenet-200/val/images'
testset = torchvision.datasets.ImageFolder(
    test_dir, transform=transform_test) 
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=0)
classes = 200
img_size = 64

# Model
print('==> Building model..')
import torchvision.models as models

from torchvision.models.resnet import ResNet, Bottleneck

num_classes = 200

print(ResNet)
class ModelParallelResNet50(ResNet):
    def __init__(self, *args, **kwargs):
        super(ModelParallelResNet50, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)

        self.seq1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,

            self.layer1,
            self.layer2
        ).to('cuda:0')

        self.seq2 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to('cuda:1')

        self.fc.to('cuda:1')

    def forward(self, x):
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))
    
net = models.resnet50()


#Finetune Final few layers to adjust for tiny imagenet input
net.avgpool = nn.AdaptiveAvgPool2d(1)
net.fc.out_features = 200
print

net = net.to(device)
# if device == 'cuda':
#     net = torch.nn.DataParallel(net)
#     cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=0.0005)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
print(net)
#%%


==> Preparing data..
==> Building model..
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2

In [None]:
# Training
wandb.watch(net)


epochs = 90
def train(epoch):
    print('Epoch:{0}/{1}'.format(epoch, epochs))
    net.train()
    
    train_loss = 0 
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)        
        loss = criterion(outputs, targets)       
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'train_Loss: %.3f | train_Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    return train_loss/(batch_idx+1), 100.*correct/total

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    acc_list = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            acc_list.append(100.*correct/total)
            
            progress_bar(batch_idx, len(testloader), 'test_Loss: %.3f | test_Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
            print('>>>best acc: {0}, mean: {1}, std: {2}'.format(best_acc, round(np.mean(acc_list), 2), round(np.std(acc_list), 2)))
            
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir(save_path+'/checkpoint'):
            os.mkdir(save_path+'/checkpoint')
        torch.save(state, save_path+'/checkpoint/ckpt.pth')
        best_acc = acc
        print('>>>best acc:', best_acc)
    
    return test_loss/(batch_idx+1), 100.*correct/total, best_acc

test_loss = 0
test_list = []
train_list = []
epoch_list = []
train_acc_list = []
test_acc_list = []
for epoch in range(start_epoch, start_epoch+epochs):
   
    epoch_list.append(epoch)
    
    train_loss, train_acc = train(epoch)
    train_list.append(train_loss)
    train_acc_list.append(train_acc)
    
    test_loss, test_acc, best_acc = test(epoch)
    test_list.append(test_loss)
    test_acc_list.append(test_acc)
    
    epoch_line = 'epoch: {0}/ total epoch: {1} '.format(epoch, epochs) 
    best_acc_line = 'best_acc: {0} '.format(best_acc)
    accuracy_line = 'train_acc: {0} %, test_acc: {1} % '.format(train_acc, test_acc)
    loss_line = 'train_loss: {0},e test_loss: {1} '.format(train_loss, test_loss)
    wandb.log({"train accuracy" : train_acc, "test accuracy" : test_acc, "train_loss" : train_loss, "test_loss" : test_loss}, step=epoch+1)

    if epoch % 1 == 0:
        plt.figure()
        ax1 = plt.subplot(2, 1, 1)
        ax1.plot(epoch_list, train_list, c = 'blue', label = 'train loss')
        ax1.plot(epoch_list, test_list, c = 'red', label = 'test loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        ax1.legend(loc=0)
        
        ax2 = plt.subplot(2, 1, 2)
        ax2.plot(epoch_list, train_acc_list, c = 'blue', label = 'train acc')
        ax2.plot(epoch_list, test_acc_list, c = 'red', label = 'test acc')
        plt.ylabel('acc')
        plt.xlabel('epoch')
        ax2.legend(loc=0)
        
        plt.savefig(save_path+'/train_history.png')

    
    with open(save_path+'/logs.txt', 'a') as f:
        f.write(epoch_line + best_acc_line + accuracy_line + loss_line + '\n')
    scheduler.step()


Epoch:0/90
>>>best acc: 0, mean: 7.0, std: 0.0................................]  Step: 211ms | Tot: 0ms | test_Loss: 4.341 | test_Acc: 7.000% (7/100) 1/100 
>>>best acc: 0, mean: 6.75, std: 0.25..............................]  Step: 109ms | Tot: 110ms | test_Loss: 4.518 | test_Acc: 6.500% (13/200) 2/100 
>>>best acc: 0, mean: 5.94, std: 1.16..............................]  Step: 108ms | Tot: 218ms | test_Loss: 4.532 | test_Acc: 4.333% (13/300) 3/100 
>>>best acc: 0, mean: 5.83, std: 1.02..............................]  Step: 110ms | Tot: 329ms | test_Loss: 4.507 | test_Acc: 5.500% (22/400) 4/100 
>>>best acc: 0, mean: 5.55, std: 1.08..............................]  Step: 107ms | Tot: 436ms | test_Loss: 4.537 | test_Acc: 4.400% (22/500) 5/100 
>>>best acc: 0, mean: 5.23, std: 1.21..............................]  Step: 120ms | Tot: 557ms | test_Loss: 4.547 | test_Acc: 3.667% (22/600) 6/100 
>>>best acc: 0, mean: 5.14, std: 1.14..............................]  Step: 112ms | Tot: 669ms | t

>>>best acc: 3.3, mean: 12.77, std: 6.13...........................]  Step: 83ms | Tot: 240ms | test_Loss: 4.213 | test_Acc: 7.750% (31/400) 4/100 
>>>best acc: 3.3, mean: 12.02, std: 5.68...........................]  Step: 78ms | Tot: 319ms | test_Loss: 4.142 | test_Acc: 9.000% (45/500) 5/100 
>>>best acc: 3.3, mean: 11.79, std: 5.21...........................]  Step: 98ms | Tot: 417ms | test_Loss: 4.066 | test_Acc: 10.667% (64/600) 6/100 
>>>best acc: 3.3, mean: 12.23, std: 4.94...........................]  Step: 80ms | Tot: 498ms | test_Loss: 3.897 | test_Acc: 14.857% (104/700) 7/100 
>>>best acc: 3.3, mean: 12.48, std: 4.67...........................]  Step: 77ms | Tot: 575ms | test_Loss: 3.947 | test_Acc: 14.250% (114/800) 8/100 
>>>best acc: 3.3, mean: 12.5, std: 4.41............................]  Step: 78ms | Tot: 654ms | test_Loss: 4.037 | test_Acc: 12.667% (114/900) 9/100 
>>>best acc: 3.3, mean: 12.39, std: 4.19...........................]  Step: 84ms | Tot: 739ms | test_Loss

>>>best acc: 10.01, mean: 17.58, std: 5.81.........................]  Step: 81ms | Tot: 607ms | test_Loss: 3.995 | test_Acc: 17.111% (154/900) 9/100 
>>>best acc: 10.01, mean: 17.48, std: 5.52.........................]  Step: 74ms | Tot: 681ms | test_Loss: 4.001 | test_Acc: 16.600% (166/1000) 10/100 
>>>best acc: 10.01, mean: 17.36, std: 5.27.........................]  Step: 73ms | Tot: 754ms | test_Loss: 4.032 | test_Acc: 16.091% (177/1100) 11/100 
>>>best acc: 10.01, mean: 17.42, std: 5.05.........................]  Step: 73ms | Tot: 828ms | test_Loss: 3.899 | test_Acc: 18.083% (217/1200) 12/100 
>>>best acc: 10.01, mean: 17.39, std: 4.86.........................]  Step: 74ms | Tot: 903ms | test_Loss: 3.971 | test_Acc: 17.000% (221/1300) 13/100 
>>>best acc: 10.01, mean: 17.3, std: 4.69..........................]  Step: 77ms | Tot: 980ms | test_Loss: 4.017 | test_Acc: 16.214% (227/1400) 14/100 
>>>best acc: 10.01, mean: 17.19, std: 4.55.........................]  Step: 76ms | Tot: 1s