In [None]:
# Start from a fresh clone of the project repository (deletes any previous checkout first).
# NOTE(review): this clones the tip of the default branch, not a pinned commit, so results
# can change as the repo evolves — consider checking out a fixed commit for reproducibility.
!rm -rf FashionMeter/
!git clone https://github.com/yalibina/FashionMeter.git

In [None]:
# Install the project's Python dependencies from the repo's requirements file.
# NOTE(review): `!pip` runs in a subshell; `%pip` is the recommended form to guarantee the
# packages are installed into the environment of the running kernel.
!cd FashionMeter && pip install -r requirements.txt

In [None]:
# Fetch the dataset with the repo's own download script, run from its own directory
# (presumably because the script uses paths relative to src/dataload — confirm).
!cd FashionMeter/src/dataload &&  ./download.sh

In [None]:
# Sanity check: show the working directory and confirm the `FashionMeter` clone is in it.
!pwd
!ls

In [None]:
# Make the repo root the working directory: the `src.*` imports and `config.yaml` used
# below are resolved relative to it.
%cd FashionMeter

In [None]:
import copy
import os
import time

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import wandb
import yaml
from pytorch_lightning.loggers import WandbLogger
from tqdm.notebook import tqdm

from src.dataload.dataset import (
    train_dataloader,
    val_dataloader,
    class_weights,
    ids2label,
    label2ids,
    NUM_LABELS
)
from src.models.vit import LitViT


In [None]:
SEED = 42
# One call seeds Python's `random`, NumPy and torch so the run is reproducible.
pl.seed_everything(SEED)

In [None]:
# Sanity check: `config.yaml`, read in the next cell, should be listed in the repo root.
!ls

In [None]:
# Hyper-parameters and paths are read from the repo's config.yaml (the cwd is the repo
# root after `%cd`). safe_load only builds plain YAML types, never arbitrary objects.
with open('config.yaml') as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# Unpack the settings used in this notebook; a missing key fails fast with KeyError.
N_EPOCHS, LR, WD = config['N_EPOCHS'], config['LR'], config['WD']
CHECKPOINT_DIR, PROJECT_NAME, MODEL_NAME = (
    config['CHECKPOINT_DIR'],
    config['PROJECT_NAME'],
    config['MODEL_NAME'],
)

print(config)

In [None]:
# Close any W&B run left open by an earlier execution, then open a new run whose only
# job here is to fetch the fine-tuned model checkpoint artifact.
wandb.finish()
run = wandb.init()

MODEL_ARTIFACT = 'mmls05/FashionMeter/10epochs_lr1e-05_wd0.01:v0'
artifact = run.use_artifact(MODEL_ARTIFACT, type='model')
# Local directory the artifact's files were downloaded into.
artifact_dir = artifact.download()

In [None]:
# Build the checkpoint path from the directory W&B actually downloaded the artifact into,
# instead of hard-coding a Colab-only absolute path ('/content/FashionMeter/artifacts/...')
# that repeats the artifact name and breaks on any other machine or working directory.
ckpt_path = os.path.join(artifact_dir, 'model.ckpt')

In [None]:
# Rebuild the Lightning module with the hyper-parameters from config.yaml, then load the
# fine-tuned weights stored under the checkpoint's 'state_dict' entry.
lit_model = LitViT(
    num_labels=NUM_LABELS,
    id2label=ids2label,
    label2id=label2ids,
    class_weights=class_weights,
    lr=LR,
    weight_decay=WD,
)
# map_location='cpu' lets a checkpoint saved from a GPU load on any machine; the model is
# moved to the evaluation device explicitly before inference.
# NOTE(review): on recent PyTorch, torch.load defaults to weights_only=True; if loading
# fails on non-tensor checkpoint entries, pass weights_only=False only for trusted files.
checkpoint = torch.load(ckpt_path, map_location='cpu')
lit_model.load_state_dict(checkpoint['state_dict'])

## Evaluate base model

In [None]:
# Baseline: accuracy and inference time of the un-pruned checkpoint on the validation set.
# Re-seed in case `val_dataloader` shuffles; seed_everything also seeds torch, so no
# separate torch.manual_seed call is needed.
pl.seed_everything(42)
# Fall back to CPU so the cell still runs on a machine without a GPU.
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'
lit_model.to(eval_device)
lit_model.eval()
all_preds = []
all_labels = []

# perf_counter is a monotonic, high-resolution clock intended for measuring durations
# (time.time follows the adjustable system clock). The timed span also covers data
# loading and host->device copies, not only the forward passes.
start = time.perf_counter()
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        pixel_values = batch['pixel_values'].to(lit_model.device)
        labels = batch['labels'].to(lit_model.device)
        logits = lit_model(pixel_values)
        preds = torch.argmax(logits, dim=1)

        # .cpu() waits for this batch's GPU work, so no kernels are still pending when
        # the timer is stopped.
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

end = time.perf_counter()

print("Time for inference:", end - start)

# Per-class precision / recall / F1 on the validation set.
from sklearn.metrics import classification_report
print(classification_report(all_labels, all_preds))

## Prune model

In [None]:
def apply_structured_pruning(model, amount=0.2, exclude=(), make_permanent=True):
    """Zero the lowest-L1-norm output neurons of every ``nn.Linear`` in ``model``.

    For each ``nn.Linear`` submodule not listed in ``exclude``, the rows of ``weight``
    along ``dim=0`` (one row per output neuron) are ranked by L1 norm and the smallest
    ``amount`` of them are masked to zero with ``torch.nn.utils.prune.ln_structured``.

    Tensor shapes are unchanged, so this alone does not reduce parameter count, memory
    or FLOPs. The ``bias`` entries of pruned neurons are left untouched.

    Args:
        model: module to prune. It is modified IN PLACE; pass a copy to keep the original.
        amount: float in [0, 1] (fraction of rows) or int (absolute number of rows) to
            prune in each layer, as accepted by ``prune.ln_structured``.
        exclude: qualified names from ``model.named_modules()`` to leave unpruned; naming
            a module also skips everything nested under it. A single string is accepted.
        make_permanent: if True (default), call ``prune.remove`` so the zeros are written
            into ``weight`` and the mask is dropped. Those zeros are then ordinary
            trainable values and will move away from zero if the model is trained further.
            If False, the mask stays attached (``weight`` is recomputed as
            ``weight_orig * weight_mask`` on each forward), which keeps pruned rows at zero
            during fine-tuning; call ``prune.remove(module, 'weight')`` afterwards.

    Returns:
        ``model`` itself (the same object).
    """
    if isinstance(exclude, str):
        exclude = (exclude,)
    excluded = tuple(exclude)

    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if any(name == skip or name.startswith(skip + '.') for skip in excluded):
            continue
        # Prune entire output neurons (dim=0), ranked by L1 norm (n=1).
        prune.ln_structured(module, name='weight', amount=amount, n=1, dim=0)
        if make_permanent:
            prune.remove(module, 'weight')
    return model


PRUNE_AMOUNT = 0.05  # fraction of output neurons zeroed in each pruned nn.Linear

# Prune a deep copy: pruning works in place, so pruning `lit_model.vit` directly would turn
# the baseline into the pruned model and make it share one backbone with `lit_pruned`.
# Masks are kept (make_permanent=False) so fine-tuning cannot regrow the pruned weights.
# NOTE(review): assumes the classification head inside `lit_model.vit` is named
# 'classifier' (as in Hugging Face ViTForImageClassification) — confirm with
# `[n for n, _ in lit_model.vit.named_modules()]`. Pruning the head's output rows would
# zero the weights of whole classes. If no module has that name, nothing is skipped.
pruned_model = apply_structured_pruning(
    copy.deepcopy(lit_model.vit),
    amount=PRUNE_AMOUNT,
    exclude=('classifier',),
    make_permanent=False,
)

In [None]:
# Wrap the pruned backbone in a fresh Lightning module configured like the baseline.
# NOTE(review): assumes LitViT reaches its backbone only through `self.vit` (forward,
# configure_optimizers, ...) — confirm in src/models/vit.py.
lit_pruned_kwargs = dict(
    num_labels=NUM_LABELS,
    id2label=ids2label,
    label2id=label2ids,
    class_weights=class_weights,
    lr=LR,
    weight_decay=WD,
)
lit_pruned = LitViT(**lit_pruned_kwargs)
lit_pruned.vit = pruned_model

In [None]:
# Inspect the device the new wrapper reports.
# NOTE(review): Lightning updates `.device` when the module is moved with `.to()`/`.cuda()`;
# assigning `.vit` does not, so this may report cpu even if the backbone's tensors are on
# the GPU. A later cell moves the whole module with `.to('cuda:0')`.
lit_pruned.device

In [None]:
from pytorch_lightning.utilities.model_summary import ModelSummary

# Per-layer parameter table of the pruned model (top-level children only). Pruning zeroes
# weights without removing them, so counts match an unpruned backbone.
pruned_summary = ModelSummary(lit_pruned, max_depth=1)
pruned_summary

In [None]:
# Move the whole wrapper (including the pruned backbone) to the first GPU, so every
# parameter and the reported `.device` agree. The trailing `;` suppresses the long module
# repr that `.to()`'s return value would otherwise print.
lit_pruned.to('cuda:0');

In [None]:
# Pruned model, before any fine-tuning: same evaluation as for the baseline.
lit_pruned.eval()
all_preds = []
all_labels = []
torch.manual_seed(42)
# perf_counter is a monotonic, high-resolution clock intended for measuring durations.
# The timed span also covers data loading and host->device copies.
start = time.perf_counter()
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        pixel_values = batch['pixel_values'].to(lit_pruned.device)
        labels = batch['labels'].to(lit_pruned.device)
        logits = lit_pruned(pixel_values)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

end = time.perf_counter()

print("Time for inference:", end - start)

# Per-class precision / recall / F1 on the validation set.
from sklearn.metrics import classification_report
print(classification_report(all_labels, all_preds))

## Fine-tune pruned model

In [None]:
# Fine-tune the pruned model to recover accuracy lost to pruning.
# Weights still behind a pruning mask (weight = weight_orig * weight_mask) stay at zero
# during this fit. Zeros baked in with `prune.remove` are ordinary trainable values and
# can drift away from zero.
FINETUNE_EPOCHS = 3
trainer = pl.Trainer(
    log_every_n_steps=10,
    max_epochs=FINETUNE_EPOCHS,
    default_root_dir=CHECKPOINT_DIR,
    deterministic=True,
    precision="16-mixed",
)
trainer.fit(lit_pruned, train_dataloader, val_dataloader)

In [None]:
# Train the same model for a few more epochs, starting from its current weights.
# This is a new Trainer, so `fit` calls the module's `configure_optimizers` again.
# Optimizer state (and any LR schedule) therefore starts over rather than resuming,
# and the epoch counter restarts at 0.
EXTRA_FINETUNE_EPOCHS = 2
trainer = pl.Trainer(
    log_every_n_steps=10,
    max_epochs=EXTRA_FINETUNE_EPOCHS,
    default_root_dir=CHECKPOINT_DIR,
    deterministic=True,
    precision="16-mixed",
)
trainer.fit(lit_pruned, train_dataloader, val_dataloader)

In [None]:
# Check which device the model is on after `trainer.fit`, before timing inference.
# NOTE(review): Lightning may move the module back to CPU when fit finishes — confirm;
# the following cells set the device explicitly either way.
lit_pruned.device

In [None]:
# Measure on GPU: fine-tuned pruned model.
lit_pruned.to('cuda:0')
lit_pruned.eval()
all_preds = []
all_labels = []
torch.manual_seed(42)
# perf_counter is a monotonic, high-resolution clock intended for measuring durations.
# The timed span also covers data loading and host->device copies; the per-batch .cpu()
# calls wait for the GPU, so no work is still queued when the timer stops.
start = time.perf_counter()
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        pixel_values = batch['pixel_values'].to(lit_pruned.device)
        labels = batch['labels'].to(lit_pruned.device)
        logits = lit_pruned(pixel_values)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

end = time.perf_counter()

print("Time for inference:", end - start)

# Per-class precision / recall / F1 on the validation set.
from sklearn.metrics import classification_report
print(classification_report(all_labels, all_preds))

In [None]:
# Measure on CPU: fine-tuned pruned model (leaves the model on CPU for the save below).
lit_pruned.to('cpu')
lit_pruned.eval()
all_preds = []
all_labels = []
torch.manual_seed(42)
# perf_counter is a monotonic, high-resolution clock intended for measuring durations.
# The timed span also covers data loading, not only the forward passes.
start = time.perf_counter()
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        pixel_values = batch['pixel_values'].to(lit_pruned.device)
        labels = batch['labels'].to(lit_pruned.device)
        logits = lit_pruned(pixel_values)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

end = time.perf_counter()

print("Time for inference:", end - start)

# Per-class precision / recall / F1 on the validation set.
from sklearn.metrics import classification_report
print(classification_report(all_labels, all_preds))

In [None]:
# If pruning masks are still attached to any Linear layer, bake them into plain `weight`
# tensors first, so the saved module carries no pruning hooks or `weight_orig` /
# `weight_mask` entries. This is a no-op when pruning was already made permanent.
for submodule in lit_pruned.modules():
    if isinstance(submodule, nn.Linear) and hasattr(submodule, 'weight_mask'):
        prune.remove(submodule, 'weight')

# Pickles the whole LightningModule (class references + weights), not just a state_dict.
# Loading it requires the `src` package to be importable and, on recent PyTorch,
# torch.load(..., weights_only=False). Saving `lit_pruned.state_dict()` is the more
# portable alternative.
torch.save(lit_pruned, "lit_vit_pruned.pth")