In [1]:
!pip install -q -U transformers peft accelerate optimum einops


In [2]:
!pip install accelerate -U



In [3]:
import os
import requests
import urllib.parse as parse
import numpy as np
from PIL import Image
from tqdm import tqdm
import transformers
import datasets
from datasets import load_dataset
from transformers import (VisionEncoderDecoderModel, GPT2TokenizerFast, ViTImageProcessor, 
                          Seq2SeqTrainingArguments, Seq2SeqTrainer, pipeline, EvalPrediction)
from torch.utils.data import DataLoader
import torch
from torch.optim import AdamW
import evaluate

from pycocoevalcap.spice.spice import Spice


import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import os

In [4]:
import os

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

world_size = torch.cuda.device_count()
print(world_size)

4


In [5]:
print(transformers.file_utils.default_cache_path)
print(datasets.config.HF_DATASETS_CACHE)

/nfs/stak/users/arulmozg/hpc-share/huggingface/hub
/nfs/stak/users/arulmozg/hpc-share/huggingface


In [6]:
def ddp_setup(rank, world_size):
    # Initialize the process group
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    
def cleanup():
    dist.destroy_process_group()


In [7]:
def is_url(string):
    try:
        result = parse.urlparse(string)
        return all([result.scheme, result.netloc, result.path])
    except:
        return False

def load_image(image_path):
    if is_url(image_path):
        return Image.open(requests.get(image_path, stream=True).raw)
    elif os.path.exists(image_path):
        return Image.open(image_path)

def preprocess(items):
    pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
    targets = tokenizer([sentence["raw"] for sentence in items["sentences"]],
                        max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
    return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([x['labels'] for x in batch])
    }

In [8]:

max_length = 32
coco_dataset_ratio = 50
train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]")
valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]")
test_ds = load_dataset("HuggingFaceM4/COCO", split="test")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [9]:

train_ds = train_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=world_size)
valid_ds = valid_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=world_size)
test_ds = test_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=world_size)

train_dataset = train_ds.with_transform(preprocess)
valid_dataset = valid_ds.with_transform(preprocess)
test_dataset = test_ds.with_transform(preprocess)

In [10]:
def get_data_loaders(batch_size, rank, world_size):

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

    valid_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler)

    return train_loader, valid_loader


In [11]:
def train_step(model, batch, optimizer, pixel_values, labels):
    optimizer.zero_grad()
    outputs = model(pixel_values=pixel_values, labels=labels)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    return loss.item()

def evaluate_model(model, valid_loader, rank):
    model.eval()
    valid_loss = 0.0
    predictions = []
    labels = []

    with torch.no_grad():
        for batch in valid_loader:
            pixel_values = batch["pixel_values"].to(rank)
            label_ids = batch["labels"].to(rank)
            outputs = model(pixel_values=pixel_values, labels=label_ids)
            loss = outputs.loss
            valid_loss += loss.item()

            logits = outputs.logits.detach().cpu()
            predictions.extend(logits.argmax(dim=-1).tolist())
            labels.extend(label_ids.tolist())

    avg_val_loss = valid_loss / len(valid_loader)
    return avg_val_loss, predictions, labels



In [12]:
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

def compute_metrics(eval_pred, tokenizer):
    preds = eval_pred.predictions
    labels = eval_pred.label_ids
    
    # Decode predictions and labels
    pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE
    rouge_result = rouge.compute(predictions=pred_str, references=labels_str)
    rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}

    # Compute BLEU
    bleu_result = bleu.compute(predictions=pred_str, references=labels_str)
    bleu_score = round(bleu_result["bleu"] * 100, 4)

    # Compute SPICE
    spice_result = compute_spice(pred_str, labels_str)
    spice_score = round(spice_result["spice"], 4)

    return {
        "rouge1": rouge_result.get("rouge1", 0),
        "rouge2": rouge_result.get("rouge2", 0),
        "rougeL": rouge_result.get("rougeL", 0),
        "bleu": bleu_score,
        "spice": spice_score
    }


def compute_spice(predictions, references):
    spice = Spice()
    
    # Create dictionaries for gts and res
    res = {i: [pred] for i, pred in enumerate(predictions)}
    gts = {i: [ref] for i, ref in enumerate(references)}
    
    # Compute SPICE score
    average_score, scores = spice.compute_score(gts, res)
    return {"spice": average_score}

def get_evaluation_metrics(model, dataset):
    model.eval()
    dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size)
    n_test_steps = len(dataloader)
    predictions, labels = [], []
    test_loss = 0.0
    for batch in tqdm(dataloader, "Evaluating"):
        pixel_values = batch["pixel_values"]
        label_ids = batch["labels"]
        outputs = model(pixel_values=pixel_values, labels=label_ids)
        loss = outputs.loss
        test_loss += loss.item()
        logits = outputs.logits.detach().cpu()
        predictions.extend(logits.argmax(dim=-1).tolist())
        labels.extend(label_ids.tolist())
    eval_prediction = EvalPrediction(predictions=predictions, label_ids=labels)
    metrics = compute_metrics(eval_prediction)
    metrics["test_loss"] = test_loss / n_test_steps
    return metrics

def get_caption(model, image_processor, tokenizer, image_path):
    image = load_image(image_path)
    img = image_processor(image, return_tensors="pt").to(device)
    output = model.generate(**img)
    caption = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return caption

def show_image_and_captions(url):
    display(load_image(url))
    our_caption = get_caption(best_model, image_processor, tokenizer, url)
    pipeline_caption = get_caption(image_captioner.model, image_processor, tokenizer, url)
    print(f"Our caption: {our_caption}")
    print(f"Abdou/vit-swin-base-224-gpt2-image-captioning caption: {pipeline_caption}")

In [13]:
'''# File to log training metrics
log_file = open("training_logs.txt", "w")

# number of training steps
n_train_steps = num_epochs * len(train_dataset_loader)
# number of validation steps
n_valid_steps = len(valid_dataset_loader)
# current training step
current_step = 0
# logging, eval & save steps
save_steps = 1000'''


'# File to log training metrics\nlog_file = open("training_logs.txt", "w")\n\n# number of training steps\nn_train_steps = num_epochs * len(train_dataset_loader)\n# number of validation steps\nn_valid_steps = len(valid_dataset_loader)\n# current training step\ncurrent_step = 0\n# logging, eval & save steps\nsave_steps = 1000'

In [14]:
def train_model(rank, world_size, train_dataloader, valid_loader, batch_size=32, num_epochs=10, learning_rate=1e-3, patience=3, weight_decay=1e-5):
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    scheduler = ReduceLROnPlateau(optimizer, patience=patience)
    best_val_loss = float('inf')
    stop_counter = 0  # Counter for early stopping

    train_losses = []
    bleu_values = []
    rouge1_values = []
    rouge2_values = []
    rougeL_values = []
    spice_values = []
    val_losses = []
    prev_lr = optimizer.param_groups[0]['lr']  # Get initial learning rate

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        train_dataloader_iter = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
        for i, data in enumerate(train_dataloader_iter):
            pixel_values = batch["pixel_values"].to(rank)
            labels = batch["labels"].to(rank)
            loss = train_step(model, batch, optimizer, pixel_values, labels)
            train_loss += loss

        avg_train_loss = train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        if rank == 0:
            avg_val_loss, predictions, labels = evaluate_model(model, valid_loader, rank)
            val_losses.append(avg_val_loss)


            print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        
            eval_prediction = EvalPrediction(predictions=predictions, label_ids=labels)
            metrics = compute_metrics(eval_prediction)
            
            bleu_values.append(metrics['bleu'])
            rouge1_values.append(metrics['rouge'])
            rouge2_values.append(metrics['rouge2'])
            rougeL_values.append(metrics['rougeL'])
            spice.append(metrics['spice'])
            
            print(f"\n BLEU: {metrics['bleu']:.4f}, " +
                f"ROUGE-1: {metrics['rouge1']:.4f}, ROUGE-2: {metrics['rouge2']:.4f}, ROUGE-L: {metrics['rougeL']:.4f}," +
                  f"SPICE: {metrics['spice']:.4f}\n")

            scheduler.step(avg_val_loss)
            current_lr = optimizer.param_groups[0]['lr']  # Get current learning rate
    
            # Print the learning rate only if it changes
            if current_lr != prev_lr:
                print("Learning rate changed to:", current_lr)
                prev_lr = current_lr  # Update previous learning rate
    
            if avg_val_loss < best_val_loss:
                torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pt')
                checkpoint_path = f"./image-captioning/checkpoint-{epoch + 1}"
                model.save_pretrained(checkpoint_path)
                tokenizer.save_pretrained(checkpoint_path)
                image_processor.save_pretrained(checkpoint_path)
    
                best_val_loss = avg_val_loss
                stop_counter = 0
            else:
                stop_counter += 1
    
            # Early stopping
            if stop_counter >= patience:
                print("Early stopping...")
                break

    print("\n---Finished Training---\n")


In [15]:
def train(rank, world_size):
    ddp_setup(rank, world_size)

    train_loader, valid_loader = get_data_loaders(batch_size, rank, world_size)

    # Initialize model
    encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
    decoder_model = "gpt2"
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_model, decoder_model).to(rank)
    model = DDP(model, device_ids=[rank])
    tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
    image_processor = ViTImageProcessor.from_pretrained(encoder_model)

    
    if "gpt2" in decoder_model:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.decoder_start_token_id = tokenizer.bos_token_id
    else:
        model.config.decoder_start_token_id = tokenizer.cls_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

    train_model(rank, world_size, train_loader, valid_loader, batch_size=32, num_epochs=10, learning_rate=1e-3, patience=3, weight_decay=1e-5)
    cleanup()

In [16]:
# Main function to launch training on all GPUs
def main():
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/nfs/stak/users/arulmozg/hpc-share/miniconda3/envs/genalt/lib/python3.11/multiprocessing/spawn.py", line 120, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/stak/users/arulmozg/hpc-share/miniconda3/envs/genalt/lib/python3.11/multiprocessing/spawn.py", line 130, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'train' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/nfs/stak/users/arulmozg/hpc-share/miniconda3/envs/genalt/lib/python3.11/multiprocessing/spawn.py", line 120, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/stak/users/arulmozg/hpc-share/miniconda3/envs/genalt/lib/python3.11/multiprocessing/spawn.py", line 130, in _m

ProcessExitedException: process 3 terminated with exit code 1

In [None]:
# Save the best model
best_model = VisionEncoderDecoderModel.from_pretrained(f"./image-captioning/checkpoint-{best_checkpoint}").to(device)

# Evaluate on the test dataset
metrics = get_evaluation_metrics(best_model, test_dataset)
print(metrics)

In [None]:

# Perform inference
image_captioner = pipeline("image-to-text", model="Abdou/vit-swin-base-224-gpt2-image-captioning")
image_captioner.model = image_captioner.model.to(device)

In [None]:

show_image_and_captions("http://images.cocodataset.org/test-stuff2017/000000000001.jpg")
# show_image_and_captions("http://images.cocodataset.org/test-stuff2017/000000000019.jpg")
# show_image_and_captions("http://images.cocodataset.org/test-stuff2017/000000000128.jpg")
# show_image_and_captions("http://images.cocodataset.org/test-stuff2017/000000003072.jpg")
# show_image_and_captions("http://images.cocodataset.org/test-stuff2017/000000003324.jpg")
# show_image_and_captions("http://images.cocodataset.org/test-stuff2017/000000003720.jpg")