In [1]:
import torch
from PIL import Image
from project.utils import MultiModalConfig, MultiModalModel


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local checkpoint directories for each sub-model family. The dict keys become
# the names under which MultiModalModel registers each sub-model
# (e.g. model.resnet_models.resnet18, model.nlp_transformers_models['bert']).
resnet_paths = {
    "resnet18": "./models/baseresNet18/",
    "resnet50": "./models/baseresNet50/",
}
vit_paths = {"vit": "./models/vitRater"}
# Text transformer checkpoints; despite the variable name this holds both the
# 'bert' and the 'roberta' entries.
bert_paths = {
    "bert": "./models/promptBert",
    "roberta": "./models/promptRoberta/",
}

config = MultiModalConfig(
    resnet_model_paths=resnet_paths,
    vit_model_paths=vit_paths,
    nlp_transformers_model_paths=bert_paths,
)



In [3]:
# Display the resolved config: class names, id/label maps, model type and the
# checkpoint paths of every sub-model.
config

MultiModalConfig {
  "class_names": [
    "PG",
    "PG-13",
    "R",
    "X",
    "XXX"
  ],
  "id2label": {
    "0": "PG",
    "1": "PG-13",
    "2": "R",
    "3": "X",
    "4": "XXX"
  },
  "label2id": {
    "PG": 0,
    "PG-13": 1,
    "R": 2,
    "X": 3,
    "XXX": 4
  },
  "model_type": "vision-text-dual-encoder",
  "nlp_transformers_model_paths": {
    "bert": "./models/promptBert",
    "roberta": "./models/promptRoberta/"
  },
  "resnet_model_paths": {
    "resnet18": "./models/baseresNet18/",
    "resnet50": "./models/baseresNet50/"
  },
  "transformers_version": "4.41.2",
  "vit_model_paths": {
    "vit": "./models/vitRater"
  }
}

In [4]:
# Build the multimodal model from the config above. `model_name_or_path` is the
# project-specific keyword of MultiModalModel.from_pretrained.
# NOTE(review): presumably 'models' is the root directory holding the
# sub-model checkpoints listed in the config; confirm against MultiModalModel.
model = MultiModalModel.from_pretrained(
    config=config,
    model_name_or_path="models",
)


In [5]:
# Inspect the ResNet-18 image backbone registered under the 'resnet18' key.
model.resnet_models.resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [6]:
# Inspect the ResNet-50 image backbone registered under the 'resnet50' key.
model.resnet_models.resnet50

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [7]:
# Inspect the text transformer sub-models, keyed 'bert' and 'roberta' as in
# bert_paths. Per the repr, the 'bert' entry is a DistilBertForSequenceClassification.
model.nlp_transformers_models

ModuleDict(
  (bert): DistilBertForSequenceClassification(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn):

In [8]:
# Inspect the final classification head. in_features (3200) should match the
# width of the fused feature vector (torch.Size([1, 3200]) is printed during
# inference), and out_features (5) matches the five class_names in the config.
model.fc

Linear(in_features=3200, out_features=5, bias=True)

In [9]:
# Inspect all image backbones in one ModuleDict.
# NOTE(review): this repeats the resnet18/resnet50 reprs already shown by the
# two single-backbone inspection cells; consider keeping only one of them.
model.resnet_models

ModuleDict(
  (resnet18): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [10]:
# Write the config and the model weights into one output directory.
# NOTE(review): if MultiModalModel.save_pretrained follows the Hugging Face
# convention it already writes the config, making the first call redundant.
output_dir = "new_models"
config.save_pretrained(output_dir)
model.save_pretrained(output_dir)

In [11]:
# Text prompt paired with the sample poster below.
prompt = "Movie Poster page, (promotional poster), Ruby Rose, 1female, solo, humanoid android, silver eyes, singer's uniform, headset, WeirdOutfit style, concert, Nippon Budokan, glowneon, glowing, sparks, lightning, shadow minimalism, (best quality), (masterpiece), detailed, beautiful detailed eyes, perfect anatomy, perfect body, perfect face, perfect hair, perfect legs, perfect hands, perfect arms, perfect fingers, detailed hair, detailed face, detailed eyes, detailed clothes, detailed skin, ultra-detailed, (full body), (upper body), (top quality), pop art, extremely detailed, extremely detailed CG, (high resolution), highly detailed, (high quality), (perfect quality), (glitchcore colors)"

# Image.open is lazy: it keeps the file handle open until the pixel data is
# read. Read the pixels inside a `with` block (copy() forces the load) so the
# file is closed right away and `img` is an independent in-memory image.
# NOTE(review): the image mode is left unchanged. If the model's preprocessing
# does not convert to RGB itself, add `.convert("RGB")` to handle
# greyscale/CMYK JPEGs.
image_path = 'data/4703CF410F5635CCEB29A525F9CF339187C512F2508DFBE710D79BCF9C464DC0.jpeg'
with Image.open(image_path) as image_file:
    img = image_file.copy()

In [12]:

# Run a single forward pass in eval mode without building the autograd graph.
model.eval()
with torch.no_grad():
    # NOTE(review): the forward pass itself appears to print the fused feature
    # shape (torch.Size([1, 3200]) in the saved output); consider removing
    # that debug print from MultiModalModel.forward.
    output = model(image=img, text=prompt)
    # assumes `output` is (batch, num_classes) logits; batch size is 1 here.
    predicted_class = torch.argmax(output, dim=1).item()

# A bare class index is not interpretable, so map it to its rating name using
# the config's ordered class list (same order as id2label).
predicted_label = config.class_names[predicted_class]
print(f"Predicted Class: {predicted_class} ({predicted_label})")

torch.Size([1, 3200])
Predicted Class: 3
