In [1]:
from utils import create_text_embeddings, calc_ann_metric_rerank2
import pandas as pd
import cv2 as cv2
from tqdm.notebook import tqdm
import torch
from models import get_model2
from torchvision import transforms
from annoy import AnnoyIndex
from transformers import AutoTokenizer, AutoModel

In [2]:
# Load both tables from the working directory. Further down the notebook
# `company_data` is handed to the evaluation as the query frame and
# `comparable_data` as the database frame.
company_csv = "company_data.csv"
comparable_csv = "comparable_data.csv"
company_data = pd.read_csv(company_csv)
comparable_data = pd.read_csv(comparable_csv)

In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Image preprocessing that is later passed to the evaluation helper alongside
# the image model: convert to a PIL image, resize to 224x224, convert to a
# tensor, then normalise every channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics —
# confirm they match what the image model was trained with.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
default_transform = transforms.Compose(preprocessing_steps)

In [5]:
# Load the pretrained text encoder and its matching tokenizer from the same
# checkpoint name, and place the encoder on the selected device.
model_name = "cointegrated/rubert-tiny2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_model = AutoModel.from_pretrained(model_name).to(device)
# Bare last expression: the cell output shows the module structure.
text_model



BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(83828, 312, padding_idx=0)
    (position_embeddings): Embedding(2048, 312)
    (token_type_embeddings): Embedding(2, 312)
    (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=312, out_features=312, bias=True)
            (key): Linear(in_features=312, out_features=312, bias=True)
            (value): Linear(in_features=312, out_features=312, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=312, out_features=312, bias=True)
            (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)

In [6]:
def _embed_csv(csv_name):
    """Run `create_text_embeddings` on one CSV file with the shared text model and tokenizer."""
    return create_text_embeddings(csv_name, text_model, tokenizer)


# Query-side embeddings first, then the database-side embeddings.
company_text_embeddings = _embed_csv("company_data.csv")
# NOTE(review): "databse" is a misspelling of "database"; the name is kept
# because the index-building cell below refers to it.
databse_text_embeddings = _embed_csv("comparable_data.csv")

  0%|          | 0/984 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


  0%|          | 0/13718 [00:00<?, ?it/s]

In [7]:
# Build an Annoy index (angular distance) over the database text embeddings.
# The vector size is taken from the first embedding.
embedding_dim = len(databse_text_embeddings[0])
forest = AnnoyIndex(embedding_dim, metric="angular")

# Wrap the embeddings themselves in tqdm, not the enumerate() iterator:
# enumerate() has no length, so tqdm(enumerate(...)) cannot show a total or
# a percentage and only prints a bare "0it" counter.
for i, item in enumerate(tqdm(databse_text_embeddings)):
    # Item id == position of the embedding in the sequence.
    forest.add_item(i, item.cpu().detach().tolist())

# Values of the "target" column, in row order.
# Presumably row i of comparable_data corresponds to item i of the index — TODO confirm.
forest_labels = comparable_data["target"].tolist()

forest.build(50)  # 50 trees
forest.save("forest_berttiny_not_trainedv1.ann")

0it [00:00, ?it/s]

True

In [8]:
# Create the image model and move it to the selected device.
img_model = get_model2()
img_model.to(device)
# Switch to evaluation mode explicitly before the model is used for inference:
# the printed architecture contains BatchNorm layers, which behave differently
# in training mode. Whether get_model2() already does this is not visible from
# this notebook; eval() is idempotent, so calling it again is harmless.
# eval() returns the module itself, so the cell still displays the architecture.
img_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
# Run the evaluation helper with the text index, the query embeddings, both
# data frames and the image model with its preprocessing.
# NOTE(review): k_text / k_img are presumably the number of text-ANN candidates
# and the number kept after image re-ranking — confirm in utils.
rerank_depths = {"k_img": 10, "k_text": 40}
calc_ann_metric_rerank2(
    forest=forest,
    query_tensors=company_text_embeddings,
    query_df=company_data,
    database_df=comparable_data,
    img_model=img_model,
    default_transform=default_transform,
    device=device,
    **rerank_depths,
)

  0%|          | 0/984 [00:00<?, ?it/s]

velotrenazhery acc at 10: 0.9176
ellipticheskie_trenazhery acc at 10: 0.9786
begovye_dorozhki acc at 10: 0.994
multistantsii acc at 10: 0.3342
steppery acc at 10: 0.1618
trenazhery_na_svobodnykh_vesakh acc at 10: 0.3565
grebnye_trenazhery acc at 10: 0.536
Total acc at 10: 0.71
