In [3]:
from src.dataset import get_datasets
import torch
from transformers import BertTokenizer
from src.eval import evaluate_model
from src.models import ImageCaptioningModelTransformer  # твоя модель
import torch.nn as nn
from src.dataset import get_datasets
import timm
from torch.utils.data import DataLoader

In [4]:
def load_model(checkpoint_path, train, device="cuda"):
    """Build the captioning model and restore its weights from a checkpoint.

    A frozen, pretrained ``vit_small_patch16_224`` backbone (classification
    head replaced by ``Identity``) is wrapped in
    ``ImageCaptioningModelTransformer`` and the saved weights are loaded
    into it.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``. The
            loaded object must contain the weights under the key
            ``"model_state"`` or ``"model_state_dict"``.
        train: Object exposing ``train.tokenizer.vocab_size``; that value is
            passed to the model constructor as the vocabulary size.
        device: Device used both as ``map_location`` and as the final
            placement of the model.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has neither of the expected keys.

    Warns:
        UserWarning: If the checkpoint's keys do not line up with the
            model's parameters. Loading stays non-strict (as before), but
            the mismatch is no longer silent.
    """
    import warnings

    vit_model = timm.create_model('vit_small_patch16_224', pretrained=True)
    vit_model.head = torch.nn.Identity()
    # The backbone is used as a fixed feature extractor.
    for p in vit_model.parameters():
        p.requires_grad = False

    vit_model.eval()

    model = ImageCaptioningModelTransformer(train.tokenizer.vocab_size, vit_model)

    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # the installed version's `weights_only` default -- only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    for key in ("model_state", "model_state_dict"):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break
    else:
        raise KeyError(
            f"Checkpoint {checkpoint_path!r} has neither 'model_state' nor "
            f"'model_state_dict'; available keys: {list(checkpoint)}"
        )

    # strict=False skips keys that do not match instead of failing; the
    # returned report is the only way to notice that weights were not loaded.
    load_report = model.load_state_dict(state_dict, strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        warnings.warn(
            "Checkpoint does not fully match the model: "
            f"{len(load_report.missing_keys)} missing key(s) "
            f"{load_report.missing_keys[:5]}, "
            f"{len(load_report.unexpected_keys)} unexpected key(s) "
            f"{load_report.unexpected_keys[:5]}. "
            "Parameters listed as missing keep their initial values."
        )

    model.to(device)
    model.eval()

    return model

In [5]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

checkpoint = "./checkpoint/run1_checkpoint_epoch_19.pth"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# get_datasets() yields three items; the middle one is not used in this notebook.
train, _, test = get_datasets()
test_loader = DataLoader(dataset=test, batch_size=8, shuffle=False)

# Rebuild the model and restore the trained weights onto the chosen device.
model = load_model(checkpoint_path=checkpoint, train=train, device=device)

In [7]:
from torchvision import transforms
from PIL import Image

# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalize each channel with the mean/std listed below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Read the test picture as RGB, preprocess it, add a leading batch axis
# and move the result to the active device.
image = transform(
    Image.open("./test_img/girl.jpg").convert("RGB")
).unsqueeze(0).to(device)  # [1, 3, 224, 224]

In [8]:
def generate_square_subsequent_mask(sz, device=None):
    """Create an additive causal mask for a Transformer decoder.

    Entries strictly above the main diagonal are ``-inf`` and all other
    entries are ``0.0``, which prevents a position from "peeking" at
    future tokens.

    Args:
        sz: Side length of the square mask (the current sequence length).
        device: Optional device on which to allocate the mask. ``None``
            (the default) uses torch's default device, as before.

    Returns:
        A float tensor of shape ``(sz, sz)``.
    """
    # Fill with -inf directly, then keep only the part above the diagonal;
    # torch.triu zeroes everything on and below it.
    blocked = torch.full((sz, sz), float('-inf'), device=device)
    return torch.triu(blocked, diagonal=1)

In [9]:
max_len = 20  # maximum number of tokens to generate for the caption
generated = [tokenizer.cls_token_id]  # start with the tokenizer's [CLS] token

# Inference only: disable autograd so no computation graph is recorded for
# the forward passes below.
with torch.no_grad():
    for _ in range(max_len):
        # causal mask sized to the tokens generated so far
        tgt_mask = generate_square_subsequent_mask(len(generated)).to(device)

        # current token ids as a batch of one: [1, seq_len]
        inp = torch.tensor(generated).unsqueeze(0).to(device)

        # forward pass; expected output shape is [1, seq_len, vocab_size]
        # according to the original author's note -- TODO confirm against the model
        output = model(image, inp, tgt_mask)

        # greedy choice: highest-scoring token at the last position
        next_token = output.argmax(-1)[0, -1].item()

        generated.append(next_token)

        # stop once the tokenizer's [SEP] token is produced
        if next_token == tokenizer.sep_token_id:
            break

# turn the token ids back into text, dropping special tokens
caption = tokenizer.decode(generated, skip_special_tokens=True)
print("Generated caption:", caption)


Generated caption: felix sneakersronezzle sliced colliery ter patterson 山 ය fr clarify glenn ter patterson 山 ය fr clarify glenn


In [10]:
# Sanity check: show the embedding matrix shape next to the tokenizer's
# vocabulary size, one value per line.
for value in (model.word_embedding.weight.shape, tokenizer.vocab_size):
    print(value)


torch.Size([30522, 256])
30522


In [3]:
# Look at one training sample (index 2). In the recorded output below, this
# printed an image file path and returned a tuple whose first element is a tensor.
train[2]

/home/sasha/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/1000268201_693b08cb0e.jpg


(tensor([[[ 0.4851,  0.4508,  0.3138,  ..., -1.8953, -1.8610, -1.8097],
          [ 0.4851,  0.4508,  0.3309,  ..., -1.8439, -1.9295, -1.5699],
          [ 0.5022,  0.4508,  0.3652,  ..., -1.7412, -1.8610, -1.9295],
          ...,
          [ 1.3927,  1.0502,  1.3070,  ..., -1.0390, -1.1418,  0.0056],
          [ 1.6495,  1.3070,  1.1529,  ..., -1.0904, -1.1589,  0.2967],
          [ 1.4440,  1.3755,  1.0331,  ..., -1.0904, -1.1760,  0.5022]],
 
         [[ 0.8354,  0.8354,  0.6779,  ..., -1.5805, -1.6506, -1.5630],
          [ 0.8179,  0.8354,  0.6779,  ..., -1.7381, -1.7206, -1.2129],
          [ 0.8529,  0.8354,  0.6779,  ..., -1.7206, -1.8081, -1.6155],
          ...,
          [ 0.5028,  0.2752,  0.5378,  ..., -0.0749, -0.2150,  0.7654],
          [ 0.8880,  0.8880,  0.4503,  ..., -0.1099, -0.2850,  0.9055],
          [ 0.3102,  0.2752, -0.1099,  ..., -0.1099, -0.3025,  1.1506]],
 
         [[ 0.6879,  0.6705,  0.4962,  ..., -1.6476, -1.6650, -1.5953],
          [ 0.5834,  0.5834,

In [None]:
from src.train import Trainer