In [None]:
# %pip install transformers bitsandbytes accelerate -qqq
# %pip install swifter pillow pytorchvideo einops peft hnswlib datasets

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, LlavaForConditionalGeneration,BitsAndBytesConfig, AutoProcessor
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
from transformers.cache_utils import Cache
from peft import LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training

import pandas as pd
import re
import swifter
from PIL import Image
import requests
from io import BytesIO
import cv2
import numpy as np
import imageio
from urllib.request import urlopen
from Encoder import MultiModalEncoder
from datetime import datetime, timedelta
import json

import pandas as pd
import warnings
warnings.filterwarnings("ignore")
pd.options.mode.chained_assignment = None

# Llava

In [None]:
# Notebook-wide settings, grouped as class attributes for dotted access (e.g. args.model_name).
class args:
    model_name = "llava-hf/llava-1.5-7b-hf"  # checkpoint id used to load the tokenizer and the model
    # Load-time quantisation switches.
    quantisation_4_bit = False
    quantisation_8_bit = True

    # NOTE(review): the TrainingArguments cell hard-codes its own per-device batch size and
    # gradient-accumulation values instead of reading these two — confirm which is intended.
    batch_size = 1
    grad_acc_steps = 8
    device = 'cuda'

compute_dtype = getattr(torch, "float16")

# 4-bit (NF4, double-quantised) bitsandbytes settings; only used when args.quantisation_4_bit is set.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

# Derive a single quantisation config from the two flags on `args` so that neither flag is
# silently ignored. 4-bit takes precedence if both are set; with both off the model is
# loaded unquantised.
if args.quantisation_4_bit:
    quantization_config = bnb_config
elif args.quantisation_8_bit:
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
    quantization_config = None

tokenizer = AutoTokenizer.from_pretrained(
    args.model_name,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
# Reuse EOS as the padding token so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

model = LlavaForConditionalGeneration.from_pretrained(
    args.model_name,
    quantization_config=quantization_config,
    device_map={"": 0},  # Single GPU
    trust_remote_code=True,
)

model

In [None]:
def print_trainable_parameters(model):
    """
    Print how many parameter elements of `model` require gradients.

    Walks `model.named_parameters()` once and reports the trainable element count,
    the total element count, and the trainable share as a percentage.
    """
    # (element count, requires_grad) for every parameter tensor.
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [None]:
# LoRA adapter settings: rank 32, alpha 64, applied to the attention and MLP projections
# plus lm_head.
# NOTE(review): target_modules are matched by module name across the whole model — confirm
# whether matches outside the language model (e.g. in the vision tower) are intended.
config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)

# The KV cache is only useful for generation; switch it off for training.
model.config.use_cache = False

# A quantised base model must be prepared for k-bit training before adapters are attached.
# Skipped for an unquantised model so its dtypes are left alone.
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
    model = prepare_model_for_kbit_training(model)

model = get_peft_model(model, config)
print_trainable_parameters(model)

In [None]:
# Load the processor from the same checkpoint as the model and tokenizer so the three
# cannot drift apart if args.model_name is changed.
processor = AutoProcessor.from_pretrained(args.model_name)

In [None]:
# retrieve closest tweets by media embedding, likes and same company

# Analogy retriever

In [None]:
from companies import companies_arr
def get_img_url(string):
    """Return the first quoted value following ``Url=`` in `string`, or None.

    Matches any attribute whose name ends in ``Url`` (the pattern is ``Url='...'``),
    so the first such attribute in the text wins.

    Args:
        string: Text to search. Non-string values (None, NaN from an empty
            DataFrame cell, ...) are treated as "no URL".

    Returns:
        The captured URL as a str, or None when nothing matches.
    """
    if not isinstance(string, str):
        return None
    match = re.search(r"Url='(.+?)'", string)
    return match.group(1) if match else None


def get_vid_url(string):
    """Return the URL of the first ``video/mp4`` VideoVariant found in `string`, or None.

    Args:
        string: Text to search. Non-string values (None, NaN from an empty
            DataFrame cell, ...) are treated as "no URL".

    Returns:
        The captured URL as a str, or None when nothing matches.
    """
    if not isinstance(string, str):
        return None
    match = re.search(r"VideoVariant\(contentType='video/mp4', url='(.+?)',", string)
    return match.group(1) if match else None


# Module-level session. get_image_from_url opens its own short-lived session per call,
# so this shared one is not used by it.
session = requests.Session()

def get_image_from_url(url, timeout=5):
    """Download `url` and return it as an RGB PIL image, or None on any failure.

    Args:
        url: Address of the image.
        timeout: Seconds passed to the HTTP request as its timeout.

    Returns:
        A fully decoded ``PIL.Image.Image`` in mode 'RGB', or None when the request
        fails, the status is an HTTP error, the Content-Type does not mention
        'image', or the payload cannot be decoded.
    """
    with requests.Session() as http:
        try:
            with http.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()  # Raises HTTPError for bad HTTP status codes
                content_type = response.headers.get('Content-Type', '').lower()
                if 'image' not in content_type:
                    return None

                image = Image.open(BytesIO(response.content))
                # Image.open only reads the header; decode now so a truncated or corrupt
                # payload is reported here as None rather than failing later in the caller.
                image.load()

                if image.mode != 'RGB':
                    image = image.convert('RGB')
                return image
        except Exception:
            # Best-effort download: any network or decoding problem means "no image".
            return None

def get_video_from_url(url, timeout=30):
    """Download `url` and return its bytes wrapped in a BytesIO, or None on failure.

    Args:
        url: Address of the video; anything ``urllib.request.urlopen`` accepts.
        timeout: Seconds passed to ``urlopen`` so a stalled connection cannot block
            the caller indefinitely.

    Returns:
        ``io.BytesIO`` positioned at the start of the downloaded data, or None when
        the download raises.
    """
    try:
        with urlopen(url, timeout=timeout) as response:
            return BytesIO(response.read())
    except Exception:
        # Best-effort download. Catching Exception (not a bare except) lets
        # KeyboardInterrupt / SystemExit still stop a long-running cell.
        return None


def filter_date_top_k_likes(df, likes, input_date, inp_unseen_brand, k=5):
    """Pick the `k` rows closest in like count among recent posts of related companies.

    Rows are kept when their ``date`` lies in the 180 days up to and including
    `input_date` and their ``inferred company`` belongs to the same group in
    ``companies_arr`` as `inp_unseen_brand`. Survivors are ranked by
    ``abs(likes - target)`` and the `k` nearest are returned.

    `df` is not modified. If its ``date`` column is already datetime-typed the
    per-call conversion is cheap, so parse it once up front when calling repeatedly.

    Args:
        df: Frame with at least ``date``, ``likes`` and ``inferred company`` columns.
        likes: Target like count.
        input_date: End of the window; a 'YYYY-MM-DD HH:MM:SS' string or anything
            else ``pd.to_datetime`` understands (datetime, Timestamp, ...).
        inp_unseen_brand: Brand used to look up the company group.
        k: Maximum number of rows to return.

    Returns:
        A DataFrame with up to `k` rows (``date`` as datetimes), possibly empty when
        no company matches, or None when no row falls inside the date window.
    """
    end_date = pd.to_datetime(input_date)
    start_date = end_date - timedelta(days=30 * 6)

    # Convert into a local Series instead of writing back into the caller's frame.
    dates = pd.to_datetime(df['date'])
    in_window = (dates >= start_date) & (dates <= end_date)
    window_df = df.loc[in_window].assign(date=dates[in_window])

    if window_df.empty:
        print(f"No matching rows found for {end_date}")
        return None

    # If several groups contain the brand, the last one listed wins.
    companies = []
    for group in companies_arr:
        if inp_unseen_brand in group:
            companies = group

    same_company = window_df[window_df['inferred company'].isin(companies)]

    # assign() builds a new frame, so nothing is written into a slice of another frame.
    ranked = same_company.assign(likes_distance=(same_company['likes'] - likes).abs())
    return ranked.nsmallest(k, 'likes_distance').drop(columns=['likes_distance'])



# filter_date_top_k_likes converts the `date` column with pd.to_datetime on every call;
# parsing it once while loading keeps that per-call conversion cheap.
analogies = pd.read_csv('dump/analogies.csv', parse_dates=['date'])
training = pd.read_csv('dump/training.csv', parse_dates=['date'])



# Sample query: posts from sony's company group, in the six months before the date,
# with like counts closest to 110.
likes = 110
input_date = '2019-03-13 12:32:12'
inp_unseen_brand = 'sony'

retrieved = filter_date_top_k_likes(training, likes, input_date, inp_unseen_brand)

In [None]:

# Instantiate the project-local MultiModalEncoder (imported from Encoder) with the 'cuda' device string.
encoder = MultiModalEncoder('cuda')

# The bare string below is only a usage note: it shows the encoder being called with a
# dict keyed by modality ('image', 'video'). The calls inside it are not executed.
"""
Example:
vid = get_video_from_url(r'https://video.twimg.com/amplify_video/106741146272395776/vid/240x240/hXX2ch5Jrk44xYW0.mp4?tag=8')
img = get_image_from_url(r'https://pbs.twimg.com/media/Eo8N3JLVoAAlDJT?format=jpg&name=small')

output = encoder({'image':img,'video':vid})
"""

In [None]:
from docarray.index import HnswDocumentIndex
from docarray import BaseDoc, DocList,DocVec
from docarray.typing import ImageUrl, VideoUrl, ID, NdArray
from typing import Optional


class Doc(BaseDoc):
    """Index schema for one stored post: its metadata plus three 768-dimensional vectors.

    `img_vector` and `vid_vector` are the fields searched by `retrieve`; the metadata
    fields `content`, `username`, `likes` and `inferred_company` are what it reads
    back from each hit.
    """
    id: ID
    # Post metadata; every field is optional and defaults to None.
    date: Optional[str] = None
    likes: Optional[int] = None
    content: Optional[str] = None
    username: Optional[str] = None
    media: Optional[str] = None
    inferred_company: Optional[str] = None
    # Links to the attached media, when present.
    img_url: Optional[ImageUrl] = None
    vid_url: Optional[VideoUrl] = None
    # Required embeddings, one per modality, each of length 768.
    img_vector: NdArray[768]
    vid_vector: NdArray[768]
    text_vector: NdArray[768]

In [None]:
# Load the stored analogy documents (with their embeddings) and index them for nearest-neighbour search.
# SECURITY: protocol='pickle' means this file is deserialised with pickle — only load dumps
# you produced yourself or otherwise trust.
docs = DocList[Doc].load_binary('dump/analogies_embeddings.pickle', compress=None, protocol='pickle')
doc_index = HnswDocumentIndex[Doc](work_dir='./tmp')
# NOTE(review): the index lives in work_dir './tmp'; confirm whether re-running this cell
# against an existing work_dir re-adds documents that are already indexed.
doc_index.index(docs)

In [None]:
def process_list(lst, func):
    """Apply a batch function to the non-None items of `lst`, keeping None placeholders.

    Args:
        lst: Input items; None entries are skipped and stay None in the output.
        func: Callable taking the list of non-None items and returning one result per
            item, in the same order. It is not called when there is nothing to process.

    Returns:
        A list the same length as `lst`, holding each item's result at its original
        position and None where the input was None.

    Raises:
        ValueError: If `func` returns a different number of results than it was
            given, since results could then no longer be matched to inputs.
    """
    present = [item for item in lst if item is not None]

    # Skip the call for an all-None (or empty) input so `func` never sees an empty batch.
    results = list(func(present)) if present else []
    if len(results) != len(present):
        raise ValueError(
            f"func returned {len(results)} results for {len(present)} inputs"
        )

    # Hand results back out in order, leaving the None slots in place.
    remaining = iter(results)
    return [None if item is None else next(remaining) for item in lst]

import torch
@torch.inference_mode()
def _get_embedding(to_encode,key):
    """Encode a batch for one modality and return one NumPy vector per item.

    `to_encode` is sent to the module-level `encoder` under `key`; the
    `pooler_output` of the matching result is split along its first axis and each
    piece is moved to the CPU and converted to a NumPy array.
    """
    pooled = encoder({key: to_encode})[key].pooler_output
    return [vector.cpu().numpy() for vector in pooled]

def get_embedding(to_encode,key):
    """Embed one item or a list of items for modality `key`.

    A single item is wrapped into a one-element list. The result is always a list,
    with None wherever the corresponding input was None.
    """
    batch = to_encode if isinstance(to_encode, list) else [to_encode]

    def encode(items):
        return _get_embedding(items, key)

    return process_list(batch, encode)

In [None]:
# First training row, kept as a scratch value; the functions in this notebook take their
# own `row` parameter and do not read this global.
row = training.iloc[0]

def _find_similar_media(media, key, search_field, limit=3):
    """Embed one media object and return the index hits for it, or None if it cannot be embedded."""
    if media is None:
        # Download failed; there is nothing to embed.
        return None
    embed = get_embedding([media], key)[0]
    if embed is None:
        return None
    return doc_index.find(query=embed, search_field=search_field, limit=limit)


def retrieve(row):
    """Collect analogy posts for `row` by media similarity and by likes/company.

    Args:
        row: A record exposing ``img_url``, ``vid_url``, ``likes``, ``date`` and
            ``inferred company``.

    Returns:
        dict with two lists of ``{'content', 'user', 'likes', 'company'}`` dicts:

        * ``'media'``: up to 3 nearest neighbours of the row's media in `doc_index`.
          Video hits take precedence over image hits when both are available. Empty
          when the row has no media or the lookup fails.
        * ``'likes_company'``: rows chosen by ``filter_date_top_k_likes`` from
          `analogies`; empty when that function finds nothing in the date window.
    """
    retrieved = {}

    try:
        results = None
        # pd.notna treats both None and NaN (an empty DataFrame cell) as "no URL".
        if pd.notna(row.img_url):
            results = _find_similar_media(get_image_from_url(row.img_url), 'image', 'img_vector')
        if pd.notna(row.vid_url):
            video_results = _find_similar_media(get_video_from_url(row.vid_url), 'video', 'vid_vector')
            if video_results is not None:
                # Keep image hits as a fallback when the video cannot be fetched.
                results = video_results
        documents = results.documents if results is not None else []
        retrieved['media'] = [
            {'content':doc.content, 'user':doc.username, 'likes':doc.likes, 'company':doc.inferred_company}
            for doc in documents
        ]
    except Exception:
        # Media analogies are best-effort; a failed lookup must not abort prompt building.
        retrieved['media'] = []

    rtr = filter_date_top_k_likes(analogies, row.likes, str(row['date']), row['inferred company'],5)
    if rtr is None:
        # Nothing fell inside the date window.
        retrieved['likes_company'] = []
    else:
        retrieved['likes_company'] = [
            {'content':doc.content, 'user':doc.username, 'likes':doc.likes, 'company':doc['inferred company']}
            for _, doc in rtr.iterrows()
        ]
    return retrieved

In [None]:
# Try the retriever on the second training row. make_prompt calls retrieve itself, so
# this result is only for inspection.
rtr = retrieve(training.iloc[1])

In [None]:
# Draft instruction-style prompt with a worked example and two positional `{}` slots
# (likes/company analogies first, media analogies second).
# make_prompt defines and uses its own template rather than this one.
# NOTE(review): the example tweets contain mis-encoded emoji sequences — confirm whether
# that mirrors the dataset before correcting the text.
prompt_template = """\


    <s>[INST] 
Using the given tweets as reference construct a tweet which has analogical similarity to those, conditioned on the fact that the post is extremely viral.

Reference tweets with similar likes: 

Tweet: ```What a great day to BE part of the #BTSARMY. Shop the deluxe version of <mention>'s new album: <hyperlink> <hyperlink>```
Username: Target
Company: Target

Tweet: ```Your grand ideas will never go off track. ðŸ’¡ #GalaxyNote20 #GalaxyxBTS <mention> ðŸ‘” Learn more: <hyperlink> <hyperlink>```
Username: Samsung
Company: Samsung


The tweet is written by the user spotify belong to spotify.

New Tweet:

Given the following analogies:

by similar likes and companies: 
{}

by media similarity:
{}




"""

In [None]:
def classify_likes(likes):
    """Map a like count to an engagement label.

    Each bound is exclusive: more than 10000 is "Viral", more than 1000 is
    "High Engagement", more than 100 is "Moderate Engagement", and anything else
    is "Low Engagement".
    """
    tiers = (
        (10000, "Viral"),
        (1000, "High Engagement"),
        (100, "Moderate Engagement"),
    )
    for lower_bound, label in tiers:
        if likes > lower_bound:
            return label
    return "Low Engagement"

In [None]:
def make_prompt(row):
    """Render the training text for one post.

    The row's analogies are looked up with `retrieve` and placed, as Python list
    reprs, into the two reference sections. The post's own username, like count,
    engagement level and text fill the "Generated" section.
    """
    analogies_found = retrieve(row)
    fields = {
        'likes_based': analogies_found['likes_company'],
        'media_based': analogies_found['media'],
        'tweet': row.content,
        'username': row.username,
        'likes': row.likes,
        'likes_level': classify_likes(row.likes),
    }
    template = '''\
Using the given tweets as reference construct a tweet which has analogical similarity to those.\

Reference tweets with similar likes and companies:
{likes_based}

Reference tweets with similar media:
{media_based}


Generated:
User: {username}
Likes: {likes}
Level: {likes_level}
Tweet: {tweet}
'''
    return template.format(**fields)

# Training

In [None]:
import joblib
# Build one prompt per training row on a thread pool (threading backend, n_jobs=-1).
# Each make_prompt call goes through retrieve, which downloads the row's media and runs
# the encoder, so this cell performs network and GPU work for every row.
# NOTE(review): all threads share the single `encoder` and `doc_index` objects — confirm
# both are safe to call concurrently.
prompts = joblib.Parallel(n_jobs=-1,backend='threading')(joblib.delayed(make_prompt)(row) for i,row in training.iterrows())

In [None]:
# Number of prompts used for training; every remaining prompt goes to the evaluation frame.
# NOTE(review): with N = 10 the training split is much smaller than the held-out split
# for any sizeable dataset — confirm this is a deliberate smoke-test setting.
N = 10
train_df = pd.DataFrame({'prompts':prompts[:N]})
test_df = pd.DataFrame({'prompts':prompts[N:]})

In [None]:
from datasets import Dataset as HFDataset

# Wrap both prompt frames as Hugging Face datasets. Their single 'prompts' column is the
# field the trainer is pointed at via dataset_text_field.
train_ds = HFDataset.from_pandas(train_df)
test_ds = HFDataset.from_pandas(test_df)

In [None]:
from transformers import TrainingArguments

# Trainer hyper-parameters.
# NOTE(review): args.batch_size (1) and args.grad_acc_steps (8) are not used here; the
# batch size and accumulation steps below are hard-coded to 4 and 4 — confirm which is intended.
# NOTE(review): an eval_dataset is given to the trainer, but no evaluation schedule is
# configured here — confirm whether evaluation is expected to run during training.
training_arguments = TrainingArguments(
    output_dir="./results_latest",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    optim='paged_adamw_32bit',
    fp16=True,
    logging_steps=10,            # log every 10 steps
    save_strategy="steps",       # checkpoint on a step schedule ...
    save_steps=25,               # ... every 25 steps (not every logging step)
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    report_to = 'wandb'          # metrics are reported to Weights & Biases
)

In [None]:
from trl import SFTTrainer

# Supervised fine-tuning on the 'prompts' text column, with sequence packing enabled.
# NOTE(review): `model` has already been wrapped by get_peft_model earlier in the notebook
# and the same LoRA config is passed again as peft_config — confirm how the installed TRL
# version treats an already-wrapped model.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset = test_ds,
    peft_config=config,
    dataset_text_field="prompts",
    max_seq_length=1024, # Adjust accordingly
    tokenizer=tokenizer,
    args=training_arguments,
    packing=True,
)

# Cast every sub-module whose name contains "norm" to float32.
# NOTE: the effect comes from nn.Module.to() converting the module in place; rebinding the
# loop variable `module` does nothing by itself.
for name, module in trainer.model.named_modules():
    if "norm" in name:
        module = module.to(torch.float32)

trainer.train()

# Saves via the notebook-level `model` object into the literal directory "output_dir",
# which is separate from the "./results_latest" checkpoint directory set in training_arguments.
model.save_pretrained("output_dir") # saves lora again

# Extension to Video and Image

In [None]:
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
from transformers.cache_utils import Cache

class MultiMediaConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA subclass whose forward pass merges `pixel_values` into the text embeddings as given.

    Only `forward` is overridden. This method never runs a vision tower or projector
    on `pixel_values`: the tensor is handed directly to
    `_merge_input_ids_with_image_features` in the position of the image features.
    NOTE(review): presumably callers are expected to pass pre-computed media features
    (image or video) that already match the language model's embedding size — confirm.
    """

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        """Run the language model on text merged with media features and optionally compute the LM loss.

        Args:
            input_ids: Token ids; embedded with the input embedding layer when
                `inputs_embeds` is not supplied.
            pixel_values: Media features merged into the token embeddings (see the class
                note; they are not encoded here).
            attention_mask, position_ids: Passed to the language model; both are replaced
                by the merge step when media features are merged.
            past_key_values: Cache from earlier decoding steps.
            inputs_embeds: Ready-made embeddings; when given, the embedding and merge
                steps are skipped entirely.
            vision_feature_layer, vision_feature_select_strategy: Accepted for signature
                compatibility but never read in this method.
            labels: Targets for the shifted next-token cross-entropy loss.
            use_cache, output_attentions, output_hidden_states, return_dict: Forwarded
                to the language model; the last three fall back to `self.config` when None.

        Returns:
            `LlavaCausalLMOutputWithPast` when `return_dict` is true; otherwise a tuple
            of the logits and remaining language-model outputs, prefixed by the loss
            when one was computed.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and media features (only when more than one token is being processed)
            if pixel_values is not None and input_ids.shape[1] != 1:
                inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
                    pixel_values, inputs_embeds, input_ids, attention_mask, position_ids
                )
                if labels is None:
                    # No labels supplied: fill with ignore_index, shaped like the merged attention mask.
                    # NOTE(review): user-supplied labels are not re-aligned to the merged
                    # sequence here — confirm they already have the merged length.
                    labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
            else:
                # Cached generation step: a single new token with both past_key_values
                # and pixel_values supplied.
                if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                    # Take a slice of the first layer's cached keys and find the positions
                    # whose value is exactly 0; those are treated as not attended.
                    first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
                    # Get the target length
                    target_seqlen = first_layer_past_key_value.shape[-1] + 1

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[batch_index, non_attended_tokens] = 0

                    attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                    # Position of the new token: number of attended positions minus one.
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # Keep only positions whose (shifted) attention mask is non-zero.
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            # Tuple output: (loss,) is prepended only when a loss was computed.
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )