In [1]:
import torchvision.transforms.v2 as v2
import dl_toolbox.datasets as datasets
from torch.utils.data import Subset, RandomSampler
import torch


# Preprocessing pipeline: resize to 224x224, convert to float32 with value
# scaling enabled, then normalize per channel (mean/std values commonly used
# with ImageNet-pretrained models).
preprocessing_steps = [
    v2.Resize(size=(224, 224), antialias=True),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = v2.Compose(preprocessing_steps)

# 45 * 700 images in total.
NB_IMG = 45 * 700
dataset = datasets.Resisc('/data/NWPU-RESISC45', transform, 'all45')

# Within every consecutive run of 700 indices, the first 100 go to validation
# and the remaining 600 to training.
# NOTE(review): assumes the dataset is ordered class by class, 700 images
# each -- confirm against datasets.Resisc.
train_indices = [idx for idx in range(NB_IMG) if idx % 700 >= 100]
val_indices = [idx for idx in range(NB_IMG) if idx % 700 < 100]
trainset = Subset(dataset, indices=train_indices)
valset = Subset(dataset, indices=val_indices)

# One pass over train_loader draws 5000 samples (with replacement) from
# trainset, served in batches of 4.
train_sampler = RandomSampler(trainset, replacement=True, num_samples=5000)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    sampler=train_sampler,
    drop_last=True,
    num_workers=6,
    pin_memory=True,
)
# Validation loader: fixed order, and the last partial batch is kept
# (drop_last=False) so that every validation sample is evaluated. With
# drop_last=True, len(valset) % batch_size samples would be silently excluded
# from the reported metrics.
val_loader = torch.utils.data.DataLoader(
    valset,
    num_workers=6,
    pin_memory=True,
    shuffle=False,
    drop_last=False,
    batch_size=8,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn.functional as F 
def train(model, criterion, device, train_loader, optimizer, epoch):
    """Run one training pass over ``train_loader`` and log progress.

    Args:
        model: Module to optimize; switched to training mode.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.
        device: Device the batches are moved to before the forward pass.
        train_loader: Iterable of dict batches with ``'image'`` and ``'label'`` entries.
        optimizer: Optimizer stepping the model parameters.
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    #optimizer.train()

    # Number of samples drawn in one epoch. When the loader uses a sampler
    # that limits the epoch (e.g. RandomSampler with num_samples), this is
    # smaller than len(dataset), which would make the progress counter
    # misleading.
    try:
        samples_per_epoch = len(train_loader.sampler)
    except (AttributeError, TypeError):
        samples_per_epoch = len(train_loader.dataset)

    for batch_idx, batch in enumerate(train_loader):
        data, target = batch['image'], batch['label']
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), samples_per_epoch,
                100. * batch_idx / len(train_loader), loss.item()))

In [3]:
def test(model, criterion, optimizer, device, test_loader):
    model.eval()
    #optimizer.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch in test_loader:
            data, target = batch['image'], batch['label']
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [70]:
import timm
import torch.nn as nn
import minlora 
from functools import partial 
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

def get_lora_config(rank):
    """Return the minlora configuration mapping layer types to LoRA factories.

    Only ``nn.Linear`` weights are targeted; each gets a
    ``minlora.LoRAParametrization.from_linear`` factory bound to ``rank``.

    Args:
        rank: Rank forwarded to the LoRA parametrization factory.

    Returns:
        dict: ``{layer_type: {parameter_name: factory}}``.
    """
    linear_weight_factory = partial(minlora.LoRAParametrization.from_linear, rank=rank)
    # A matching nn.Conv2d entry (LoRAParametrization.from_conv2d) is
    # intentionally left out; add it here to parametrize convolutions too.
    lora_config = {
        nn.Linear: {"weight": linear_weight_factory},
    }
    return lora_config

class test_lora(nn.Module):
    """Pretrained timm ViT-B/16 classifier (45 classes) with optional freezing and LoRA.

    A feature extractor ending at the ``'norm'`` graph node is built from the
    same model and used as the handle for freezing parameters and attaching
    LoRA parametrizations; the forward pass itself goes through ``self.model``.
    NOTE(review): this relies on the extractor sharing its modules and
    parameters with ``self.model`` -- confirm for the installed torchvision.

    An EfficientNet-B0 variant (timm ``'efficientnet_b0'``, feature node
    ``'bn2.act'``) was used previously in place of the ViT.

    Args:
        freeze: If truthy, set ``requires_grad = False`` on every parameter of
            the feature extractor.
        lora: If truthy, add LoRA to the feature extractor via
            ``minlora.add_lora`` using ``get_lora_config(rank)``.
        rank: LoRA rank passed to ``get_lora_config``.
    """

    def __init__(self, freeze, lora, rank):
        super().__init__()
        self.model = timm.create_model(
            'vit_base_patch16_224',
            pretrained=True,
            global_pool='token',
            num_classes=45
        )
        self.feature_extractor = create_feature_extractor(
            self.model,
            {'norm': 'features'}
        )

        # Lists the traceable node names, useful for picking the return node above.
        print(get_graph_node_names(self.model))
        # Freeze first so that the LoRA parameters added afterwards stay trainable.
        if freeze:
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
        if lora:
            cfg = get_lora_config(rank)
            minlora.add_lora(self.feature_extractor, lora_config=cfg)

    def forward(self, x):
        """Return the output of the wrapped timm model for input ``x``."""
        # Call the module rather than .forward() so hooks registered on
        # self.model are not silently skipped.
        return self.model(x)

In [71]:
torch.manual_seed(1)

# Build the frozen backbone with rank-4 LoRA and report how many parameters
# still require gradients.
model = test_lora(freeze=True, lora=True, rank=4)
parameters = list(model.parameters())
trainable_parameters = [p for p in parameters if p.requires_grad]
num_trainable = sum(p.numel() for p in trainable_parameters)
num_total = sum(p.numel() for p in parameters)
print(
    f"The model will start training with only {num_trainable} "
    f"trainable parameters out of {num_total}."
)
    
    
def name_is_lora(name):
    """Tell whether a dotted parameter name designates a LoRA matrix.

    A name matches when it has at least four dot-separated components, the
    fourth from the end is ``parametrizations`` and the last one is
    ``lora_A`` or ``lora_B``.
    """
    parts = name.split(".")
    if len(parts) < 4:
        return False
    return parts[-4] == "parametrizations" and parts[-1] in ("lora_A", "lora_B")

def name_is_head(name):
    """Tell whether a dotted parameter name belongs to the classification head.

    The second component of the name is compared against the head attribute
    names of the wrapped timm models: ``classifier`` (EfficientNet) and
    ``head`` (VisionTransformer). Names with fewer than two components are
    never head parameters.
    """
    parts = name.split(".")
    return len(parts) > 1 and parts[1] in ("classifier", "head")

def get_params_by_name(model, print_shapes=False, name_filter=None):
    """Lazily yield the parameters of ``model`` whose names pass ``name_filter``.

    Args:
        model: Object exposing ``named_parameters()``.
        print_shapes: If truthy, print each selected name and shape as it is yielded.
        name_filter: Optional predicate on the parameter name; ``None`` selects all.

    Yields:
        The selected parameters, in ``named_parameters()`` order.
    """
    for param_name, param in model.named_parameters():
        if name_filter is not None and not name_filter(param_name):
            continue
        if print_shapes:
            print(param_name, param.shape)
        yield param

# Cross-check the trainable-parameter count by summing the LoRA matrices and
# the classification-head parameters separately.
lora_parameters = get_params_by_name(model, name_filter=name_is_lora)
head_parameters = get_params_by_name(model, name_filter=name_is_head)
num_lora = sum(p.numel() for p in lora_parameters)
num_head = sum(p.numel() for p in head_parameters)
print(
    f"The model should start training with only {num_lora} trainable lora parameters, plus {num_head} head params = {num_head+num_lora}."
)

#criterion = nn.CrossEntropyLoss()
#optimizer = torch.optim.AdamW(
#    trainable_parameters,
#    lr=1e-3,
#)
#
#device = 'cuda'
#model = model.to(device)
#for epoch in range(1, 10):
#    train(model, criterion, device, train_loader, optimizer, epoch)
#    test(model, criterion, optimizer, device, val_loader)

(['x', 'patch_embed.getattr', 'patch_embed.getitem', 'patch_embed.getitem_1', 'patch_embed.getitem_2', 'patch_embed.getitem_3', 'patch_embed.eq', 'patch_embed._assert', 'patch_embed.eq_1', 'patch_embed._assert_1', 'patch_embed.proj', 'patch_embed.flatten', 'patch_embed.transpose', 'patch_embed.norm', 'pos_embed', 'cls_token', 'getattr', 'getitem', 'expand', 'cat', 'add', 'pos_drop', 'patch_drop', 'norm_pre', 'blocks.0.norm1', 'blocks.0.attn.getattr', 'blocks.0.attn.getitem', 'blocks.0.attn.getitem_1', 'blocks.0.attn.getitem_2', 'blocks.0.attn.qkv', 'blocks.0.attn.reshape', 'blocks.0.attn.permute', 'blocks.0.attn.unbind', 'blocks.0.attn.getitem_3', 'blocks.0.attn.getitem_4', 'blocks.0.attn.getitem_5', 'blocks.0.attn.q_norm', 'blocks.0.attn.k_norm', 'blocks.0.attn.scaled_dot_product_attention', 'blocks.0.attn.transpose', 'blocks.0.attn.reshape_1', 'blocks.0.attn.proj', 'blocks.0.attn.proj_drop', 'blocks.0.ls1', 'blocks.0.drop_path1', 'blocks.0.add', 'blocks.0.norm2', 'blocks.0.mlp.fc1', 