In [1]:
import torch
import torch.nn as nn
from torchsummary import summary
from timm.models import create_model
import resmlp
from utils import AverageMeter, accuracy

from tqdm import tqdm
import os

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [2]:
#resmlp = create_model('resmlp_24', num_classes=NUM_CLASSES).cuda()
#summary(resmlp, (3, 224, 224))

In [3]:
def evaluate_model(model, test_loader, device, criterion=None):
    """Evaluate ``model`` on ``test_loader`` and return averaged metrics.

    Args:
        model: network to evaluate; it is switched to eval mode and moved
            to ``device``.
        test_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device the model and every batch are moved to.
        criterion: optional loss callable invoked as
            ``criterion(outputs, labels)``. When ``None`` no loss is computed
            and the reported loss is ``0.0``.

    Returns:
        Tuple ``(avg_loss, avg_top1, avg_top5)`` read from the ``.avg`` of the
        three ``AverageMeter`` instances; each batch's value is passed to
        ``AverageMeter.update`` together with its batch size.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    model.to(device)

    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_size = inputs.size(0)

            outputs = model(inputs)

            # Without a criterion there is no tensor to call .item() on;
            # record a plain 0.0 instead.
            if criterion is not None:
                loss_value = criterion(outputs, labels).item()
            else:
                loss_value = 0.0

            prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
            losses.update(loss_value, batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

    return losses.avg, top1.avg, top5.avg

In [4]:
from datasets import build_dataset
from timm.data import Mixup
import numpy as np
import torch.backends.cudnn as cudnn
from timm.loss import SoftTargetCrossEntropy
import torch.optim as optim

## Set Parameters

In [5]:
# Passed to build_dataset as `input_size`.
INPUT_SIZE = 224

# Optimisation hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 100
LR = 0.003

# Checkpoint handed to torch.load and the dataset root handed to build_dataset.
DICT_PATH = 'ResMLP_S24_ReLU_99dense.pth'
# Raw string: '\d' is not a valid escape sequence and triggers a warning
# in a normal string literal. The resulting value is unchanged.
DATA_DIR = r'E:\datasets'
# Number of DataLoader worker processes; 0 loads batches in the main process.
#WORKERS = 8
WORKERS = 0

## Train Process

In [6]:
# Training / evaluation setup on a single device.
# NOTE(review): no DataParallel / distributed wrapper is applied in this cell.

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the torch and numpy RNGs for reproducibility.
seed = 336 # args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

# Let cuDNN benchmark its kernels and pick the fastest ones for the input shapes.
cudnn.benchmark = True

# Build train/val datasets; the train split also reports the class count.
dataset_train, NUM_CLASSES = build_dataset(is_train=True, name='imagenet2012', root=DATA_DIR, input_size=INPUT_SIZE)
dataset_val, _ = build_dataset(is_train=False, name='imagenet2012', root=DATA_DIR, input_size=INPUT_SIZE)

# create sampler (if dataset from tfds, can't apply sampler) (distributed ver. to be done)
# sampler_train = torch.utils.data.RandomSampler(dataset_train)
# sampler_val = torch.utils.data.SequentialSampler(dataset_val)

# Build the data loaders.
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, #sampler=sampler_train,
    batch_size=BATCH_SIZE,
    num_workers=WORKERS,
    pin_memory=True,
    drop_last=True,
)

data_loader_val = torch.utils.data.DataLoader(
    dataset_val, #sampler=sampler_val,
    batch_size=int(1.5 * BATCH_SIZE),
    num_workers=WORKERS,
    pin_memory=True,
    drop_last=False
)

# Additional data augmentation (mixup / cutmix) for the training loop.
# NOTE(review): mixup yields soft targets, which pair with the commented-out
# SoftTargetCrossEntropy below rather than nn.CrossEntropyLoss - confirm
# before wiring mixup_fn into training.
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=1.0, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=NUM_CLASSES)

# Create the model on the selected device (honours the CPU fallback above
# instead of unconditionally calling .cuda()).
model = create_model('resmlp_24', num_classes=NUM_CLASSES).to(device)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
state_dict = torch.load(DICT_PATH, map_location=device)
# strict=False tolerates key mismatches; report them instead of hiding them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Checkpoint key mismatch - missing: {} unexpected: {}".format(
        load_result.missing_keys, load_result.unexpected_keys))

# decide layers to train

# Set optimizer, lr_scheduler and criterion.
#criterion = SoftTargetCrossEntropy() # for mixup
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),
                      lr=LR,
                      momentum=0.9,
                      weight_decay=1e-4)
# NOTE(review): with EPOCHS = 100 the milestones [100, 150] are never reached
# inside a 0..EPOCHS-1 loop, so the learning rate stays constant - confirm intent.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[100, 150],
                                                 gamma=0.1,
                                                 last_epoch=-1)

# Baseline evaluation of the loaded checkpoint before any training
# (reported as epoch -1).
model.eval()
eval_loss, top1_acc, top5_acc = evaluate_model(model=model,
                                               test_loader=data_loader_val,
                                               device=device,
                                               criterion=criterion)
print("Epoch: {:02d} Eval Loss: {:.3f} Top1: {:.3f} Top5: {:.3f}".format(
    -1, eval_loss, top1_acc, top5_acc))

# start training

  4%|▍         | 45/1042 [00:33<12:20,  1.35it/s] 


RuntimeError: DataLoader worker (pid(s) 28288) exited unexpectedly

## Quantize

In [None]:
# Quantization-aware training (QAT) followed by conversion to int8.
# `resmlp` is the imported module (its create_model assignment is commented
# out earlier in the notebook), so this cell operates on `model`, the network
# built in the training cell.
model.train()  # prepare_qat expects the model to be in training mode
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
# NOTE(review): the 'qnnpack' qconfig targets the QNNPACK backend; confirm that
# torch.backends.quantized.engine matches on the machine running convert().
#model_fp32_fused = torch.quantization.fuse_modules(model, [['blocks.mlp', 'relu']])
# inplace=True inserts the fake-quant modules into `model` itself.
model_fp32_prepared = torch.quantization.prepare_qat(model.to(device), inplace=True)
#training_loop(model_fp32_prepared)
model_fp32_prepared.eval()
model_fp32_prepared.cpu()  # conversion and int8 inference run on the CPU

#input_fp32 = torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda")
#model_fp32_prepared(input_fp32)

input_fp32 = torch.randn((1, 3, INPUT_SIZE, INPUT_SIZE), dtype=torch.float32, device="cpu")
model_int8 = torch.quantization.convert(model_fp32_prepared)
# Smoke-test the converted model on one random input; no gradients needed.
with torch.no_grad():
    output_int8 = model_int8(input_fp32)
output_int8

In [None]:
#for name, value in resmlp.named_parameters():
#    print(name)