In [None]:
%cd ..
!rm -rf FashionMeter/
!git clone https://github.com/yalibina/FashionMeter.git

In [None]:
# Install the project's Python dependencies listed in the cloned repo.
# NOTE(review): `%pip` targets the kernel's environment more reliably than `!pip`;
# left as `!cd ... && pip` in case requirements.txt relies on being run from the repo dir.
!cd FashionMeter && pip install -r requirements.txt

In [None]:
# Fetch the dataset with the repo's download script, run from its own directory
# (presumably the script uses relative paths — verify before changing the cwd).
!cd FashionMeter/src/dataload &&  ./download.sh

In [None]:
# Sanity check: show the current directory and confirm FashionMeter/ was cloned here.
!pwd
!ls

In [None]:
# Enter the repo root so the `src.*` imports below and `config.yaml` resolve relative to it.
%cd FashionMeter

In [None]:
import os
import time

import pytorch_lightning as pl
import torch
import wandb
import yaml
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.model_summary import ModelSummary
from sklearn.metrics import classification_report
from tqdm.notebook import tqdm

from src.dataload.dataset import (
    train_dataloader,
    val_dataloader,
    class_weights,
    ids2label,
    label2ids,
    NUM_LABELS
)
from src.models.vit import LitViT, LitViTQuantized


In [None]:
# Seed the global RNGs via Lightning so runs are reproducible.
pl.seed_everything(seed=42)

In [None]:
# The repo root was already entered a few cells above. A second unconditional
# `%cd FashionMeter` only works when run from the parent directory (hidden
# state); otherwise it points at a non-existent directory. Descend only if we
# are still in the parent, then confirm we are at the repo root, where
# config.yaml is read from in the next cells.
if not os.path.exists('config.yaml') and os.path.isdir('FashionMeter'):
    os.chdir('FashionMeter')
assert os.path.isfile('config.yaml'), f"config.yaml not found in {os.getcwd()}"

In [None]:
# Sanity check: list the current directory, expected to be the repo root containing config.yaml.
!ls

In [None]:
# Read the experiment hyper-parameters from the repo-level YAML config.
# safe_load only builds plain Python objects (no arbitrary tags).
with open('config.yaml', mode='r') as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# Unpack the config into module-level constants; a missing key raises KeyError here.
N_EPOCHS, LR, WD = config['N_EPOCHS'], config['LR'], config['WD']
CHECKPOINT_DIR, PROJECT_NAME, MODEL_NAME = (
    config['CHECKPOINT_DIR'],
    config['PROJECT_NAME'],
    config['MODEL_NAME'],
)

print(config)

In [None]:
# Pull the fine-tuned checkpoint logged to W&B. `artifact_dir` is the local
# folder that download() populated.
ARTIFACT_REF = 'mmls05/FashionMeter/10epochs_lr1e-05_wd0.01:v0'
run = wandb.init()
artifact = run.use_artifact(ARTIFACT_REF, type='model')
artifact_dir = artifact.download()

## Evaluate the base (full-precision) model

In [None]:
# Rebuild the LightningModule with the training hyper-parameters and load the
# fine-tuned weights from the W&B artifact downloaded above.
# - `artifact_dir` (returned by artifact.download()) replaces a hard-coded
#   /content/... path, so this works from any working directory.
# - map_location='cpu' lets the checkpoint load on hosts without a GPU;
#   load_state_dict copies the tensors into the module's own parameters.
lit_model = LitViT(
    num_labels=NUM_LABELS,
    id2label=ids2label,
    label2id=label2ids,
    class_weights=class_weights,
    lr=LR,
    weight_decay=WD,
)
checkpoint = torch.load(os.path.join(artifact_dir, 'model.ckpt'), map_location='cpu')
lit_model.load_state_dict(checkpoint['state_dict'])

In [None]:
# Evaluate the full-precision model on the validation set: wall-clock inference
# time plus a per-class classification report.
# Use the GPU when present instead of hard-coding 'cuda' (which crashes on CPU-only hosts).
# NOTE: the quantized model evaluated later runs on CPU, so the two
# "Time for inference" numbers are only comparable if both run on the same device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lit_model.to(device)
lit_model.eval()
all_preds = []
all_labels = []
torch.manual_seed(42)
# perf_counter is monotonic and high-resolution, unlike time.time(). The
# per-batch .cpu() calls force a device sync, so the measured time includes the
# actual model work (and data loading).
start = time.perf_counter()
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        pixel_values = batch['pixel_values'].to(lit_model.device)
        labels = batch['labels'].to(lit_model.device)
        logits = lit_model(pixel_values)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

end = time.perf_counter()

print("Time for inference:", end - start)

print(classification_report(all_labels, all_preds))

## Evaluate the dynamically quantized (int8) model

In [None]:
# Build a dynamically int8-quantized copy of the fine-tuned ViT.
# Reload fresh weights on CPU: the model above may have been moved to the GPU,
# and PyTorch dynamic quantization targets CPU execution.
lit_model = LitViT(
    num_labels=NUM_LABELS,
    id2label=ids2label,
    label2id=label2ids,
    class_weights=class_weights,
    lr=LR,
    weight_decay=WD,
)
checkpoint = torch.load(os.path.join(artifact_dir, 'model.ckpt'), map_location='cpu')
lit_model.load_state_dict(checkpoint['state_dict'])
# A freshly constructed module is in training mode. Switch to eval before
# quantizing so the (deep-copied) quantized model does not inherit training-mode
# behaviour in layers such as dropout.
lit_model.eval()

hf_model = lit_model.vit

# Replace every nn.Linear with a dynamically quantized int8 version (weights
# quantized ahead of time, activations quantized on the fly).
# torch.ao.quantization is the canonical namespace; torch.quantization is a legacy alias.
quantized_model = torch.ao.quantization.quantize_dynamic(hf_model, {torch.nn.Linear}, dtype=torch.qint8)

lit_quant = LitViTQuantized(quantized_model=quantized_model, num_labels=lit_model.num_labels)


In [None]:
# Show which device lit_quant reports; the evaluation loop below moves each batch there.
# (PyTorch dynamic int8 quantization executes on CPU.)
lit_quant.device

In [None]:
# Layer / parameter summary of the quantized wrapper.
# NOTE(review): dynamically quantized Linear layers hold packed int8 weights rather
# than nn.Parameters, so they may be missing from the parameter totals — confirm
# before using these numbers to compare model sizes.
ModelSummary(lit_quant)

In [None]:
# Evaluate the quantized model on the same validation set with the same metrics
# as the base model, so the two reports can be compared directly.
# NOTE: this runs on lit_quant.device (CPU for dynamic quantization), so timing
# is only comparable to the base model when that one also ran on CPU.
lit_quant.eval()
all_preds = []
all_labels = []
torch.manual_seed(42)
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        pixel_values = batch['pixel_values'].to(lit_quant.device)
        labels = batch['labels'].to(lit_quant.device)
        logits = lit_quant(pixel_values)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

end = time.perf_counter()

print("Time for inference:", end - start)

print(classification_report(all_labels, all_preds))

In [None]:
# Pickle the whole quantized LightningModule (structure + int8 weights), not just a state_dict.
# NOTE(review): reloading needs LitViTQuantized importable and, on PyTorch versions
# where torch.load defaults to weights_only=True, an explicit weights_only=False.
QUANTIZED_MODEL_PATH = "lit_vit_quantized_full.pth"
torch.save(lit_quant, QUANTIZED_MODEL_PATH)