In [2]:
'''
 * The Tag2Text Model
 * Written by Xinyu Huang
 * Edited by Jungwook Seo
'''

import os
import json
import glob
import gc
import itertools

import torch
from tqdm import tqdm

from PIL import Image
from ram.models import tag2text
from ram import inference_tag2text as inference
from ram import get_transform


In [9]:
        
        
class Bias2Tag():
    """Thin wrapper that owns a Tag2Text model and preprocesses images for it.

    The model is not loaded in the constructor; call `load_model()` before
    running inference and `off_model()` to release it again.

    Args:
        gpu_num: Index of the CUDA device to use when CUDA is available.
        dataset: Dataset name forwarded to `get_transform`.
        class_name: Mapping from class id (as a string) to a readable label.
        conflict_ratio: Stored as given; not used by the methods in this class.
        root_path: Stored as given; not used by the methods in this class.
        pretrained_path: Directory that contains
            `tag2text/tag2text_swin_14m.pth`.
        tag2text_thres: Tagging threshold assigned to the loaded model.
        image_size: Image size passed to both the transform and the model.
    """

    # Sample images used by `generate_tag_json` when no paths are supplied.
    DEFAULT_IMAGE_PATHS = (
        '/mnt/sdc/Debiasing/benchmarks/bffhq/0.5pct/align/0/62242_0_0.png',
        '/mnt/sdc/Debiasing/benchmarks/bffhq/0.5pct/align/0/62387_0_0.png',
    )

    def __init__(self,
                 gpu_num: int,
                 dataset: str,
                 class_name: dict[str, str],
                 conflict_ratio: str,
                 root_path: str,
                 pretrained_path: str,
                 tag2text_thres: float,
                 image_size=224):
        self.root_path = root_path
        self.conflict_ratio = conflict_ratio
        self.dataset = dataset
        self.pretrained_path = os.path.join(pretrained_path, 'tag2text', 'tag2text_swin_14m.pth')
        self.image_size = image_size
        self.tag2text_thres = tag2text_thres
        # Fall back to CPU so the class is still usable on machines without CUDA.
        self.device = torch.device(f'cuda:{gpu_num}' if torch.cuda.is_available() else 'cpu')
        self.tag2text_model = None
        self.class_name = class_name

    def load_model(self):
        """Build Tag2Text from the checkpoint, set it to eval mode and move it to `self.device`."""
        self.tag2text_model = tag2text(pretrained=self.pretrained_path,
                                       image_size=self.image_size,
                                       vit='swin_b')
        self.tag2text_model.thres = self.tag2text_thres  # thres for tagging
        self.tag2text_model.eval()
        self.tag2text_model = self.tag2text_model.to(self.device)
        print(f"Tag2Text has been loaded. Device: {self.device}")

    def off_model(self):
        """Drop the model reference and release cached GPU memory.

        Safe to call repeatedly, including when no model is loaded.
        """
        self.tag2text_model = None
        # Collect first: tensors freed by the collector go back to PyTorch's
        # caching allocator, and only then can empty_cache() hand them back
        # to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate_tag_json(self, image_paths=None):
        """Load images and preprocess them into single-image batches on `self.device`.

        NOTE(review): despite the name, this method does not run inference or
        write a JSON file yet; it only returns preprocessed tensors.

        Args:
            image_paths: Iterable of image file paths. When omitted, the two
                sample images in `DEFAULT_IMAGE_PATHS` are used.

        Returns:
            A tuple with one tensor per path, each being the transformed
            image with a leading batch dimension added.
        """
        if image_paths is None:
            image_paths = self.DEFAULT_IMAGE_PATHS

        transform = get_transform(dataset=self.dataset,
                                  image_size=self.image_size)

        batches = []
        for path in image_paths:
            # Close the file handle as soon as the pixel data has been
            # consumed by the transform.
            with Image.open(path) as img:
                batches.append(transform(img).unsqueeze(0).to(self.device))
        # TODO: load the model if needed and run `inference` on each batch.
        return tuple(batches)
        
        
# Tagger configured for the bFFHQ benchmark (class 0 = young, class 1 = old).
# root_path previously read '/mnt/sdc/Debiaisng', a typo: every other path in
# this notebook lives under '/mnt/sdc/Debiasing'.
bias2tag = Bias2Tag(gpu_num=6,
                    dataset='bffhq',
                    conflict_ratio='0.5',
                    class_name={
                                '0': 'young person',
                                '1': 'old person'},
                    root_path='/mnt/sdc/Debiasing',
                    pretrained_path='/mnt/sdc/Debiasing/pretrained',
                    tag2text_thres=0.68)

In [10]:
# Preprocess the two sample bFFHQ images loaded by generate_tag_json; each
# result is a single-image batch tensor placed on bias2tag.device.
a, b = bias2tag.generate_tag_json()

In [28]:
# Sanity check: stacking both single-image batches along dim 0 should give
# a batch of two images.
torch.cat((a, b)).shape

torch.Size([2, 3, 224, 224])

In [16]:
# Build Tag2Text from the pretrained checkpoint and move it to the configured
# device. This must run before the inference cell, which reads
# bias2tag.tag2text_model.
bias2tag.load_model()

/encoder/layer/0/crossattention/self/query is tied
/encoder/layer/0/crossattention/self/key is tied
/encoder/layer/0/crossattention/self/value is tied
/encoder/layer/0/crossattention/output/dense is tied
/encoder/layer/0/crossattention/output/LayerNorm is tied
/encoder/layer/0/intermediate/dense is tied
/encoder/layer/0/output/dense is tied
/encoder/layer/0/output/LayerNorm is tied
/encoder/layer/1/crossattention/self/query is tied
/encoder/layer/1/crossattention/self/key is tied
/encoder/layer/1/crossattention/self/value is tied
/encoder/layer/1/crossattention/output/dense is tied
/encoder/layer/1/crossattention/output/LayerNorm is tied
/encoder/layer/1/intermediate/dense is tied
/encoder/layer/1/output/dense is tied
/encoder/layer/1/output/LayerNorm is tied
--------------
/mnt/sdc/Debiasing/pretrained/tag2text/tag2text_swin_14m.pth
--------------
Position interpolate visual_encoder.layers.0.blocks.0.attn.relative_position_bias_table from 23x23 to 13x13
Position interpolate visual_enc

In [25]:
# Run Tag2Text on the second sample image only (`b`); `a` is not used here.
res = inference(b, bias2tag.tag2text_model)

In [26]:
# Display the raw inference result. Judging from the recorded output it is a
# 3-tuple of (pipe-separated tags, None, caption) — TODO confirm the meaning
# of the middle element against ram.inference_tag2text.
res

('song | hat | woman | band | stage | microphone | person | smile | sing | wear | singe | perform',
 None,
 'a woman wearing a hat sings a song with a band performing on stage')