In [2]:
import torch
import pandas as pd
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer


OSError: [WinError 1114] A dynamic link library (DLL) initialization routine failed. Error loading "c:\Users\Home\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\lib\c10.dll" or one of its dependencies.

In [None]:
# Preprocessed ArtEmis annotations; each row must provide an image path and a caption.
# Raw string => single backslashes (inside r"..." a "\\" is two literal backslashes).
CSV_PATH = r"D:\IML_CSV\artemis_preprocessed_with_paths2.csv"

df = pd.read_csv(CSV_PATH)

# Fail fast if the columns that ArtEmisDataset reads are missing.
missing_cols = {"img_resized_path", "caption"} - set(df.columns)
assert not missing_cols, f"CSV is missing columns: {sorted(missing_cols)}"

df.head()


Device set to use cuda:0


In [None]:
model_name = "nlpconnect/vit-gpt2-image-captioning"

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)

# The dataset pads captions with padding="max_length", which requires a pad token.
# GPT-2 style tokenizers may not define one; reuse EOS, the usual GPT-2 convention.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# When `labels` are passed, VisionEncoderDecoderModel builds decoder inputs by shifting
# the labels right, which needs both of these ids on the config.
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id
if model.config.decoder_start_token_id is None:
    model.config.decoder_start_token_id = tokenizer.bos_token_id

model.to("cpu")   # CPU training


In [None]:
from sklearn.model_selection import GroupShuffleSplit

# Split by image, not by row. If an image has several caption rows, a row-level split
# would put the same image in both train and validation, and validation loss would look
# better than it really is. The fixed seed keeps the split reproducible.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
train_idx, val_idx = next(splitter.split(df, groups=df["img_resized_path"]))
train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

assert set(train_df["img_resized_path"]).isdisjoint(val_df["img_resized_path"])


In [None]:
class ArtEmisDataset(Dataset):
    """Map-style dataset pairing each row's image with its caption.

    Each row of ``df`` must provide ``img_resized_path`` (an image file on disk)
    and ``caption`` (the target text).

    Args:
        df: Table with one (image, caption) pair per row; indexed positionally.
        feature_extractor: Image processor that turns a PIL image into ``pixel_values``.
        tokenizer: Decoder tokenizer; must have a pad token for max-length padding.
        max_length: Captions are padded/truncated to this many tokens.
    """

    def __init__(self, df, feature_extractor, tokenizer, max_length=64):
        self.df = df
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _build_labels(input_ids, attention_mask):
        """Return ``input_ids`` with padded positions (mask == 0) set to -100.

        -100 is the default ``ignore_index`` of the cross-entropy loss, so padding
        no longer contributes to training.
        """
        return input_ids.masked_fill(attention_mask == 0, -100)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": (C, H, W) tensor, "labels": (max_length,) tensor}``."""
        row = self.df.iloc[idx]

        with Image.open(row["img_resized_path"]) as img:
            pixel_values = self.feature_extractor(
                images=img.convert("RGB"), return_tensors="pt"
            ).pixel_values.squeeze(0)  # drop only the batch dim added by the processor

        # Append EOS so the decoder learns where a caption ends. Padding is masked out
        # below, so this is the only end-of-sequence target the model sees.
        text = row["caption"]
        if self.tokenizer.eos_token:
            text = text + self.tokenizer.eos_token

        encoding = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        labels = self._build_labels(
            encoding["input_ids"], encoding["attention_mask"]
        ).squeeze(0)

        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }


In [None]:
# Build one dataset per split, then wrap each in a loader.
# Batches are kept small because this run happens on the CPU.
train_ds, val_ds = (
    ArtEmisDataset(split_df, feature_extractor, tokenizer)
    for split_df in (train_df, val_df)
)

train_loader = DataLoader(train_ds, shuffle=True, batch_size=4)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=4)


KeyboardInterrupt: 

In [None]:
# CPU smoke-test run: one epoch to check that the pipeline works end to end.
device = "cpu"
# Don't rely on an earlier cell having left the model on the CPU. A later cell can
# move it to CUDA, and the batches below are sent to `device`.
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

epochs = 1   # start with 1 epoch on CPU

for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"epoch {epoch + 1}/{epochs}"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss

        # Clear grads before backward, so stale grads from an interrupted run are not applied.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss / max(len(train_loader), 1)}")


In [None]:
def generate_caption(image_path, max_length=64, **generate_kwargs):
    """Generate a caption for one image using the notebook's global ``model``.

    Args:
        image_path: Path to an image file readable by PIL.
        max_length: Maximum number of tokens to generate (matches the training length).
        **generate_kwargs: Extra options forwarded to ``model.generate`` (e.g. ``num_beams=4``).

    Returns:
        The decoded caption with special tokens removed.
    """
    # The training cells call model.train(); switch to eval so dropout is off at inference.
    model.eval()

    with Image.open(image_path) as img:
        pixel_values = feature_extractor(
            images=img.convert("RGB"), return_tensors="pt"
        ).pixel_values

    # Follow the model's own device instead of a global `device` that other cells reassign.
    pixel_values = pixel_values.to(model.device)

    output_ids = model.generate(pixel_values, max_length=max_length, **generate_kwargs)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return caption


In [None]:
# Rebuild the datasets/loaders with larger batches for the GPU run below.
BATCH_SIZE = 8

train_ds = ArtEmisDataset(train_df, feature_extractor, tokenizer)
val_ds = ArtEmisDataset(val_df, feature_extractor, tokenizer)

# Pinned host memory speeds up the .to("cuda") copies in the training loop.
# It has no benefit (and is skipped) when no GPU is present.
pin = torch.cuda.is_available()
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, pin_memory=pin)


In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Fresh optimizer over the model's (possibly relocated) parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)


In [None]:
epochs = 3

for epoch in range(epochs):
    # --- training ---
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"train {epoch + 1}/{epochs}"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        loss = model(pixel_values=pixel_values, labels=labels).loss

        # Clear grads before backward, so stale grads from an interrupted run are not applied.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    train_loss = total_loss / max(len(train_loader), 1)

    # --- validation: the held-out split is otherwise never evaluated ---
    model.eval()
    val_total = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"val {epoch + 1}/{epochs}"):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            val_total += model(pixel_values=pixel_values, labels=labels).loss.item()
    val_loss = val_total / max(len(val_loader), 1)

    print(f"Epoch {epoch+1} | Train Loss = {train_loss:.4f} | Val Loss = {val_loss:.4f}")


In [None]:
# Caption a held-out validation image. df.iloc[0] may be a training example, which
# would only show that the model can reproduce what it was trained on.
sample = val_df.iloc[0]
test_img = sample["img_resized_path"]
print("reference:", sample["caption"])
print("generated:", generate_caption(test_img))
