In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models, datasets
from efficientnet_pytorch import EfficientNet
import collections
import sys
from mpemu import mpt_emu
from mpemu import qutils
import timm
sys.path.insert(0, "./ViT-pytorch")
from models.modeling import VisionTransformer, CONFIGS
import numpy as np
# ---------------------------------------------------------------------------
# Hyperparameters and run switches
# ---------------------------------------------------------------------------
BATCH_SIZE = 128
NUM_EPOCHS = 53
# NOTE(review): the train() cell in this notebook builds its optimizer with a
# literal lr=5e-4 and does not read LEARNING_RATE - confirm which is intended.
LEARNING_RATE = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
QUANTIZE = False   # True: wrap the model with the mpemu e4m3 emulator in train()
TEST_MODEL = True  # True: train() only evaluates on the validation set

# ---------------------------------------------------------------------------
# EfficientNet variants (currently all disabled)
# ---------------------------------------------------------------------------
# Original notes, kept verbatim (presumably max batch sizes - TODO confirm):
#   b0 max: 64
#   b3 max: 32?
# To enable a variant, add e.g.
#   "b0": EfficientNet.from_pretrained("efficientnet-b0", num_classes=5)
# for any of b0 ... b7.
efficientnets = {}

# Input resolution used for each EfficientNet variant.
efficientnet_sizes = {
    "b0": 224,
    "b1": 240,
    "b2": 260,
    "b3": 300,
    "b4": 380,
    "b5": 456,
    "b6": 528,
    "b7": 600,
}

# ---------------------------------------------------------------------------
# Vision Transformer (R50 + ViT-B/16 hybrid), 384x384 input, 5 classes
# ---------------------------------------------------------------------------
# Alternatives tried earlier: timm "vit_base_r50_s16_384" via
# timm.create_model(model_name, pretrained=True, num_classes=5), and loading
# the original weights with vit.load_from(np.load("R50-ViT-B_16.npz")).
model_name = "R50-ViT-B_16"
vit = VisionTransformer(CONFIGS[model_name], 384, zero_head=True, num_classes=5)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints that come from a trusted source.
model_ckpt = torch.load("r50/checkpoint_8_noncont.pt", map_location="cpu")

# The checkpoint is either a training checkpoint dict (weights stored under
# "model_state_dict") or a bare state dict.
_state_dict = model_ckpt["model_state_dict"] if "model_state_dict" in model_ckpt else model_ckpt

# strict=False tolerates key mismatches; surface them instead of silently
# continuing with partially initialised weights.
_load_result = vit.load_state_dict(_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model - "
        f"missing keys: {list(_load_result.missing_keys)}, "
        f"unexpected keys: {list(_load_result.unexpected_keys)}"
    )

vit.to(DEVICE)
# NOTE(review): this rebinds the name `models`, shadowing the torchvision
# `models` module imported at the top of this cell.
models = [vit]




In [2]:
# Exploratory: list the timm model names matching the "*vit*" wildcard.
# The result is only displayed as the cell output; it is not assigned.
timm.list_models("*vit*")

['convit_base',
 'convit_small',
 'convit_tiny',
 'crossvit_9_240',
 'crossvit_9_dagger_240',
 'crossvit_15_240',
 'crossvit_15_dagger_240',
 'crossvit_15_dagger_408',
 'crossvit_18_240',
 'crossvit_18_dagger_240',
 'crossvit_18_dagger_408',
 'crossvit_base_240',
 'crossvit_small_240',
 'crossvit_tiny_240',
 'gcvit_base',
 'gcvit_small',
 'gcvit_tiny',
 'gcvit_xtiny',
 'gcvit_xxtiny',
 'levit_128',
 'levit_128s',
 'levit_192',
 'levit_256',
 'levit_256d',
 'levit_384',
 'maxvit_base_224',
 'maxvit_large_224',
 'maxvit_nano_rw_256',
 'maxvit_pico_rw_256',
 'maxvit_rmlp_nano_rw_256',
 'maxvit_rmlp_pico_rw_256',
 'maxvit_rmlp_small_rw_224',
 'maxvit_rmlp_small_rw_256',
 'maxvit_rmlp_tiny_rw_256',
 'maxvit_small_224',
 'maxvit_tiny_224',
 'maxvit_tiny_pm_256',
 'maxvit_tiny_rw_224',
 'maxvit_tiny_rw_256',
 'maxvit_xlarge_224',
 'maxxvit_rmlp_nano_rw_256',
 'maxxvit_rmlp_small_rw_256',
 'maxxvit_rmlp_tiny_rw_256',
 'mobilevit_s',
 'mobilevit_xs',
 'mobilevit_xxs',
 'mobilevitv2_050',
 'mobi

In [None]:
from sklearn.model_selection import train_test_split
import time
from lion_pytorch import Lion
from tqdm import tqdm
from collections import OrderedDict


def _evaluate(eval_net, loader, criterion, returns_tuple):
    """Run one gradient-free pass over ``loader`` and return its metrics.

    Args:
        eval_net: Callable applied to every image batch.
        loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        returns_tuple: True when ``eval_net(images)`` returns a 2-tuple whose
            first element holds the logits; False when it returns the logits
            directly.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged over
        ``len(loader.dataset)`` samples.
    """
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in tqdm(loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = eval_net(images)
            if returns_tuple:
                outputs, _ = outputs
            loss = criterion(outputs, labels)
            # Weight the batch-mean loss by batch size so the final average
            # is per sample even when the last batch is smaller.
            loss_sum += loss.item() * images.size(0)
            _, predictions = torch.max(outputs, 1)
            correct += (predictions == labels).sum().item()
    num_samples = len(loader.dataset)
    return loss_sum / num_samples, correct / num_samples


def train(
    train_dir="/mnt/onetb/diabetic-retinopathy/FasterTransformer/examples/pytorch/vit/ViT-quantization/data_512_split/train",
    val_dir="/mnt/onetb/diabetic-retinopathy/FasterTransformer/examples/pytorch/vit/ViT-quantization/data_512_split/val",
    image_size=384,
):
    """Evaluate or fine-tune every model in the module-level ``models`` list.

    For each model the function builds ImageFolder datasets and loaders,
    optionally wraps the model with the mpemu e4m3 emulator (``QUANTIZE``),
    and then either

    * evaluates once on the validation set (``TEST_MODEL`` is true), or
    * trains for ``NUM_EPOCHS`` epochs with SGD, validating after each epoch,
      writing ``checkpoint_<epoch>.pt`` and appending a line to
      ``loss_acc.txt``.

    Args:
        train_dir: Root of the training ImageFolder (one sub-directory per class).
        val_dir: Root of the validation ImageFolder.
        image_size: Side length images are resized to (square).

    Reads the module-level names ``models``, ``model_name``, ``BATCH_SIZE``,
    ``NUM_EPOCHS``, ``DEVICE``, ``QUANTIZE`` and ``TEST_MODEL``.
    """
    for model in models:
        # Augmentation is applied to training images only; both pipelines
        # normalise with the ImageNet channel mean / std.
        train_transforms = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(20),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        val_transforms = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
        val_dataset = datasets.ImageFolder(val_dir, transform=val_transforms)
        print(train_dataset.classes)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=8,
            pin_memory=True,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,  # order is irrelevant for aggregated validation metrics
            num_workers=8,
            pin_memory=True,
        )

        criterion = nn.CrossEntropyLoss()
        # Previously tried: Lion(lr=1e-5, weight_decay=1e-2) and Adam(lr=1e-4).
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=5e-4,
            momentum=0.9,
            weight_decay=0,
        )

        # Layers excluded from quantization.
        # NOTE(review): ["conv1", "fc"] is the ResNet/ResNeXt choice from the
        # original example (VGG19-BN used ["features.0", "classifier.6"],
        # Inception-v3 ["Conv2d_1a_3x3.conv", "fc"]); confirm these names
        # match layers of the model actually being quantized.
        list_exempt_layers = ["conv1", "fc"]

        if QUANTIZE:
            model, emulator = mpt_emu.quantize_model(
                model,
                optimizer=optimizer,
                dtype="e4m3",
                device=DEVICE,
                verbose=True,
                list_exempt_layers=list_exempt_layers,
                list_layers_output_fused=[],
            )
            emulator.set_default_inference_qconfig()

        # The un-quantized model returns (logits, extra); the fused quantized
        # model is called as returning logits directly.
        returns_tuple = not QUANTIZE

        # Evaluation-only mode.
        if TEST_MODEL:
            eval_model = emulator.fuse_bnlayers_and_quantize_model(model) if QUANTIZE else model
            model.eval()
            val_loss, val_acc = _evaluate(eval_model, val_loader, criterion, returns_tuple)
            print(torch.cuda.memory_stats(DEVICE))
            print(f"Val   loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
            # Move on to the next model instead of returning, so every entry
            # of `models` gets evaluated.
            continue

        # Train and validate the model.
        for epoch in range(NUM_EPOCHS):
            start = time.time()
            start_total = time.time()
            print(f"Epoch {epoch+1}/{NUM_EPOCHS}")

            # --- training pass ---
            model.train()
            train_loss = 0.0
            train_correct = 0
            for images, labels in tqdm(train_loader):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs, _ = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                # Gradients are cleared right after the step, so the next
                # backward() always starts from zero.
                optimizer.zero_grad()
                train_loss += loss.item() * images.size(0)
                _, predictions = torch.max(outputs, 1)
                train_correct += (predictions == labels).sum().item()
            train_loss /= len(train_dataset)
            train_acc = train_correct / len(train_dataset)
            print(f"Train loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
            print(f"Train time: {time.time() - start}")
            print(torch.cuda.memory_stats(DEVICE))

            # --- validation pass ---
            start = time.time()
            eval_model = emulator.fuse_bnlayers_and_quantize_model(model) if QUANTIZE else model
            # Switch to inference mode for validation; model.train() at the
            # top of the next epoch switches back.
            model.eval()
            val_loss, val_acc = _evaluate(eval_model, val_loader, criterion, returns_tuple)
            print(f"Val   loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

            # Save checkpoint (metrics are stored as plain Python floats).
            checkpoint_path = f"checkpoint_{epoch+1}.pt"
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                },
                checkpoint_path,
            )
            total_end = time.time() - start_total

            # Append this epoch's metrics to the running log.
            # Note: the epoch column is 0-based here, unlike the checkpoint name.
            with open("loss_acc.txt", "a") as file:
                file.write(
                    f"{model_name}, {train_loss:.4f}, {train_acc:.4f}, {val_loss:.4f}, {val_acc:.4f}, {epoch}, {BATCH_SIZE}, {total_end}\n"
                )

            print(f"Val and misc time: {time.time() - start}")
            print(f"Total time: {total_end}")


# Entry point: call train() when this code executes as the main module.
if __name__ == "__main__":
    train()

['0', '1', '2', '3', '4']


 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊       | 53/55 [01:16<00:02,  1.28s/it]